Pytorch深入浅出(三)之网络模型的构建(下)

在构建复杂的神经网络(如ResNet, DenseNet)时,如果仅仅像堆积木一样一层层手写 self.layer1, self.layer2… 代码会变得非常冗长且难以维护。
PyTorch 提供了三种特殊的容器(Containers),帮助我们更优雅地组织和管理网络层。它们分别是:nn.Sequentialnn.ModuleListnn.ModuleDict

在这里插入图片描述

模型容器Container

1.nn.Sequential
功能: 是nn.Module的容器,用于按顺序包装一组网络层。
应用场景: 顺序流式网络结构
还是以LeNet为例,我们将LeNet分成features和classifier两部分,每个部分都是一个sequential:
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
	# V1
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
        nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2)
        )
		self.classifier = nn.Sequential(
		nn.Linear(16*5*5, 120),
		nn.ReLU(),
		nn.Linear(120, 84),
		nn.ReLU(),
		nn.Linear(84, num_classes)
		)
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = self.featurs(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x

    '''
class LeNet(nn.Module):
	# V2 —— 全程流式
    def __init__(self, num_classes=10):
        super().__init__()
		self.net = nn.Sequential(
        nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
		nn.Linear(16*5*5, 120),
		nn.ReLU(),
		nn.Linear(120, 84),
		nn.ReLU(),
		nn.Linear(84, num_classes)
		)
    def forward(self, x):
        return self.net(x)
	'''

这种写法虽然简单,但网络层的名字会自动按顺序编一个号作为key 0, 1, 2。当网络很深时(比如第 50 层),调试时看到 KeyError: '48' 很难反应过来这层具体是做什么的,可读性较差。

在这里插入图片描述
进阶用法:使用OrderedDict命名
Sequential也有相应的应对方法,即为每一层网络命名,具体代码如下所示:

class LeNet(nn.Module):
	def __init__(self, num_classes):
		super().__init__()
		self.features = nn.Sequential(OrderedDict({
		'conv1': nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
		'relu1': nn.ReLU(inplace=True),
		'pool1': nn.MaxPool2d(kernel_size=2, stride=2),
		'conv2': nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
		'relu2': nn.ReLU(inplace=True),
		'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
		}))
		
		 self.classifier = nn.Sequential(OrderedDict({
		 'fc1': nn.Linear(16*5*5, 120),
		 'relu3': nn.ReLU(),
		 'fc2': nn.Linear(120, 84),
		 'relu4': nn.ReLU(inplace=True),
		 'fc3': nn.Linear(84, classes)
		}))
		
    def forward(self, x):
        x = self.featurs(x)
        x = x.view(x.size()[0], -1)
        x = self.classifier(x)
        return x

与原来不同的地方就是,构建了一个OrderedDict字典来存放键值对,key就是每一层网络的名字,value就是具体的网络层实现,查找和调试变得非常直观。看一下此时的module属性内部:
在这里插入图片描述
✅ 特点总结

  • 顺序性: 严格按照定义的顺序执行。
  • 自动化: 容器自带 forward() 函数。在模型前向传播时,不需要手动写 x = layer1(x),直接调用 self.features(x) 即可,它会自动循环执行内部所有层。

2.nn.ModuleList
功能: 存放子模块的“列表容器”,它像Python原生的list,用于存储包装一组网络层。但它比普通list多了一个关键功能——参数注册
❓ 为什么不用普通的 Python list?
PyTorch 的 nn.Module 无法识别这些层是模型的一部分,导致这些层的参数不会被加到 model.parameters() 中,优化器也无法更新它们。必须使用 nn.ModuleList 进行包装。

主要方法:

  • append():在ModuleList末尾添加网络层。
  • extend():拼接两个ModuleList。
  • insert():在指定位置插入网络层。

应用场景: 适合构建大量重复的层(如 RNN 的时间步、ResNet 的多个 Block)。
一个完整的ResNet模型构建代码如下,具体ModuleList使用见class ResBlock

import torch
from torch import nn
from torch.nn import functional as F

class Residual(nn.Module):
    '''
    内部卷积Sequential,外部残差做判断if use_conv1x1
    '''
    def __init__(self, in_channel, out_channel, use_conv1x1=False, stride=1):
        super().__init__()
        self.Conv2D = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, stride=stride),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel))
        if use_conv1x1:
            self.shortcut = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride)
            self.bn = nn.BatchNorm2d(out_channel)
        else:
            self.shortcut = None
  
    def forward(self, x):
        y = self.Conv2D(x)
        if self.shortcut:
            x = self.shortcut(x)
            x = self.bn(x)  # If using 1x1 conv, also apply BN
        y = y + x
        return F.relu(y)

class ResBlock(nn.Module):
    """
    容器类型:nn.ModuleList,仅将子模块注册到模型中,不定义前向传播逻辑,灵活性更高(可自定义控制流)
    nn.ModuleList是存放子模块的“列表容器”,PyTorch能追踪里面的参数
    """
    def __init__(self, in_channel, out_channel, num_residuals, first_block=False):
        super().__init__()
        self.res_blocks = nn.ModuleList()
  
        for i in range(num_residuals):
            # 从第二个大残差结构开始, 结构中的第一个残差块都是下采样映射残差块
            if i == 0 and not first_block:    
                self.res_blocks.append(Residual(in_channel, out_channel, use_conv1x1=True, stride=2))
            else:
                # 后续残差块:通道数不变
                self.res_blocks.append(Residual(out_channel, out_channel))
  
    def forward(self, x):
        for res in self.res_blocks:   # 显式遍历 ModuleList 并调用每个子模块
            x = res(x)
        return x

class ResNet18(nn.Module):
    def __init__(self, num_classes, in_channel):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            ResBlock(64, 64, 2, first_block=True),
            ResBlock(64, 128, 2),
            ResBlock(128, 256, 2),
            ResBlock(256, 512, 2),
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(512, num_classes))
    def forward(self, x):
        return self.net(x)

✅ 特点总结

  • 注册机制: 解决了 Python 原生列表无法被 PyTorch 识别为模型参数的问题(即自动将列表内的层注册到 _modules 中)。
  • 灵活性: 容器不包含 forward() 函数。这意味着它不限制数据流动的顺序,你可以在 forward 中使用 for 循环、索引访问(self.layers[2](x)),甚至跳过某些层。
  • 迭代性: 本质是一个可迭代对象,非常适合配合列表推导式使用。

3.nn.ModuleDict
功能: 它像 Python 原生的 dict,允许通过键值对(Key-Value)的方式存储和访问网络层。同样,它也能自动注册参数。

主要方法:

  • clear():清空ModuleDict
  • items():返回可迭代的键值对
  • keys():返回字典的键
  • values():返回字典的值
  • pop():返回一对键值,并从字典中删除

应用场景: 因为键值对可以索引的特性,可用于选择网络层。适合构建动态网络可选择的网络分支。比如根据参数决定使用哪种激活函数或卷积层。

class ModuleDict(nn.Module):
    def __init__(self):
        super(ModuleDict, self).__init__()
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'prelu': nn.PReLU()
        })

    def forward(self, x, choice_key, act_key):
        x = self.choices[choice_key](x)
        x = self.activations[act_key](x)
        return x

net = ModuleDict()
Digital_Img = torch.randn((16, 1, 28, 28))
output = net(Digital_Img, 'conv', 'relu') # 选择使用conv和relu
print(output)

✅ 特点总结

  • 索引性: 可以通过字符串 Key 精准访问特定的网络层,代码可读性极高。
  • 选择性: 容器不包含 forward() 函数。它常用于实现“多选一”的逻辑(Switch-Case 效果),非常适合用于算法研究中对比不同模块的效果。
  • 状态管理:ModuleList 一样,它负责将字典内的所有子模块注册到模型中,确保参数不会丢失。

4.总结与对比

容器类型核心特性是否自带 forward典型应用场景
nn.Sequential顺序性✅ 是像搭积木一样构建标准的、按顺序执行的网络块(Block)
nn.ModuleList迭代性❌ 否 (需手动写)需要大量重复构建层,或在 forward 中需要灵活控制循环逻辑时
nn.ModuleDict索引性❌ 否 (需手动写)需要根据参数动态选择网络分支,或管理非线性的网络结构时
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值