在构建复杂的神经网络(如ResNet, DenseNet)时,如果仅仅像堆积木一样一层层手写 self.layer1, self.layer2… 代码会变得非常冗长且难以维护。
PyTorch 提供了三种特殊的容器(Containers),帮助我们更优雅地组织和管理网络层。它们分别是:nn.Sequential、nn.ModuleList 和 nn.ModuleDict。

模型容器Container
1.nn.Sequential
功能: 是nn.Module的容器,用于按顺序包装一组网络层。
应用场景: 顺序流式网络结构
还是以LeNet为例,我们将LeNet分成features和classifier两部分,每个部分都是一个sequential:

import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet(nn.Module):
# V1
def __init__(self, num_classes=10):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, num_classes)
)
self.flatten = nn.Flatten()
def forward(self, x):
x = self.featurs(x)
x = self.flatten(x)
x = self.classifier(x)
return x
'''
class LeNet(nn.Module):
# V2 —— 全程流式
def __init__(self, num_classes=10):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Flatten(),
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, num_classes)
)
def forward(self, x):
return self.net(x)
'''
这种写法虽然简单,但网络层的名字会自动按顺序编一个号作为key 0, 1, 2。当网络很深时(比如第 50 层),调试时看到 KeyError: '48' 很难反应过来这层具体是做什么的,可读性较差。

进阶用法:使用OrderedDict命名
Sequential也有相应的应对方法,即为每一层网络命名,具体代码如下所示:
class LeNet(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.features = nn.Sequential(OrderedDict({
'conv1': nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
'relu1': nn.ReLU(inplace=True),
'pool1': nn.MaxPool2d(kernel_size=2, stride=2),
'conv2': nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
'relu2': nn.ReLU(inplace=True),
'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
}))
self.classifier = nn.Sequential(OrderedDict({
'fc1': nn.Linear(16*5*5, 120),
'relu3': nn.ReLU(),
'fc2': nn.Linear(120, 84),
'relu4': nn.ReLU(inplace=True),
'fc3': nn.Linear(84, classes)
}))
def forward(self, x):
x = self.featurs(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x
与原来不同的地方就是,构建了一个OrderedDict字典来存放键值对,key就是每一层网络的名字,value就是具体的网络层实现,查找和调试变得非常直观。看一下此时的module属性内部:

✅ 特点总结
- 顺序性: 严格按照定义的顺序执行。
- 自动化: 容器自带
forward()函数。在模型前向传播时,不需要手动写x = layer1(x),直接调用self.features(x)即可,它会自动循环执行内部所有层。
2.nn.ModuleList
功能: 存放子模块的“列表容器”,它像Python原生的list,用于存储包装一组网络层。但它比普通list多了一个关键功能——参数注册。
❓ 为什么不用普通的 Python list?
PyTorch 的 nn.Module 无法识别这些层是模型的一部分,导致这些层的参数不会被加到 model.parameters() 中,优化器也无法更新它们。必须使用 nn.ModuleList 进行包装。
主要方法:
append():在ModuleList末尾添加网络层。extend():拼接两个ModuleList。insert():在指定位置插入网络层。
应用场景: 适合构建大量重复的层(如 RNN 的时间步、ResNet 的多个 Block)。
一个完整的ResNet模型构建代码如下,具体ModuleList使用见class ResBlock
import torch
from torch import nn
from torch.nn import functional as F
class Residual(nn.Module):
'''
内部卷积Sequential,外部残差做判断if use_conv1x1
'''
def __init__(self, in_channel, out_channel, use_conv1x1=False, stride=1):
super().__init__()
self.Conv2D = nn.Sequential(
nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, stride=stride),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True),
nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channel))
if use_conv1x1:
self.shortcut = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride)
self.bn = nn.BatchNorm2d(out_channel)
else:
self.shortcut = None
def forward(self, x):
y = self.Conv2D(x)
if self.shortcut:
x = self.shortcut(x)
x = self.bn(x) # If using 1x1 conv, also apply BN
y = y + x
return F.relu(y)
class ResBlock(nn.Module):
"""
容器类型:nn.ModuleList,仅将子模块注册到模型中,不定义前向传播逻辑,灵活性更高(可自定义控制流)
nn.ModuleList是存放子模块的“列表容器”,PyTorch能追踪里面的参数
"""
def __init__(self, in_channel, out_channel, num_residuals, first_block=False):
super().__init__()
self.res_blocks = nn.ModuleList()
for i in range(num_residuals):
# 从第二个大残差结构开始, 结构中的第一个残差块都是下采样映射残差块
if i == 0 and not first_block:
self.res_blocks.append(Residual(in_channel, out_channel, use_conv1x1=True, stride=2))
else:
# 后续残差块:通道数不变
self.res_blocks.append(Residual(out_channel, out_channel))
def forward(self, x):
for res in self.res_blocks: # 显式遍历 ModuleList 并调用每个子模块
x = res(x)
return x
class ResNet18(nn.Module):
def __init__(self, num_classes, in_channel):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64), nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
ResBlock(64, 64, 2, first_block=True),
ResBlock(64, 128, 2),
ResBlock(128, 256, 2),
ResBlock(256, 512, 2),
nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten(),
nn.Linear(512, num_classes))
def forward(self, x):
return self.net(x)
✅ 特点总结
- 注册机制: 解决了 Python 原生列表无法被 PyTorch 识别为模型参数的问题(即自动将列表内的层注册到
_modules中)。 - 灵活性: 容器不包含
forward()函数。这意味着它不限制数据流动的顺序,你可以在forward中使用for循环、索引访问(self.layers[2](x)),甚至跳过某些层。 - 迭代性: 本质是一个可迭代对象,非常适合配合列表推导式使用。
3.nn.ModuleDict
功能: 它像 Python 原生的 dict,允许通过键值对(Key-Value)的方式存储和访问网络层。同样,它也能自动注册参数。
主要方法:
clear():清空ModuleDictitems():返回可迭代的键值对keys():返回字典的键values():返回字典的值pop():返回一对键值,并从字典中删除
应用场景: 因为键值对可以索引的特性,可用于选择网络层。适合构建动态网络或可选择的网络分支。比如根据参数决定使用哪种激活函数或卷积层。
class ModuleDict(nn.Module):
def __init__(self):
super(ModuleDict, self).__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.ModuleDict({
'relu': nn.ReLU(),
'prelu': nn.PReLU()
})
def forward(self, x, choice_key, act_key):
x = self.choices[choice_key](x)
x = self.activations[act_key](x)
return x
net = ModuleDict()
Digital_Img = torch.randn((16, 1, 28, 28))
output = net(Digital_Img, 'conv', 'relu') # 选择使用conv和relu
print(output)
✅ 特点总结
- 索引性: 可以通过字符串 Key 精准访问特定的网络层,代码可读性极高。
- 选择性: 容器不包含
forward()函数。它常用于实现“多选一”的逻辑(Switch-Case 效果),非常适合用于算法研究中对比不同模块的效果。 - 状态管理: 与
ModuleList一样,它负责将字典内的所有子模块注册到模型中,确保参数不会丢失。
4.总结与对比
| 容器类型 | 核心特性 | 是否自带 forward | 典型应用场景 |
|---|---|---|---|
nn.Sequential | 顺序性 | ✅ 是 | 像搭积木一样构建标准的、按顺序执行的网络块(Block) |
nn.ModuleList | 迭代性 | ❌ 否 (需手动写) | 需要大量重复构建层,或在 forward 中需要灵活控制循环逻辑时 |
nn.ModuleDict | 索引性 | ❌ 否 (需手动写) | 需要根据参数动态选择网络分支,或管理非线性的网络结构时 |
2173

被折叠的 条评论
为什么被折叠?



