1 介绍
- 在 PyTorch 中,
nn.ModuleDict
是一个方便的容器,用于存储一组子模块(即 nn.Module
对象)的字典
- 这个容器主要用于动态地管理多个模块,并通过键来访问它们,类似于 Python 的字典
2 特点
- 组织性
nn.ModuleDict
提供了一种将多个模块有序组织在一起的方法。
- 这有助于让代码更加结构化,易于理解和维护
- 动态操作
- 可以像操作普通字典那样添加或删除模块
- 例如使用
module_dict['key'] = module
添加模块,使用 del module_dict['key']
删除模块
- 自动参数注册
- 当将模块添加到
ModuleDict
中时,它们的参数会自动注册到整个网络中,确保在模型训练时这些参数可以被识别和更新
3 例子
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.layers = nn.ModuleDict({
'linear': nn.Linear(10, 20),
'activation': nn.R