深入理解 PyTorch 中的 ‘children()‘ 和 ‘modules()‘

在使用 PyTorch 构建和调试神经网络时,了解模型的结构非常重要。PyTorch 提供了 children()modules() 两个方法来遍历和查看模型的子模块。本文将详细介绍这两个方法的区别,并演示如何使用它们来输出模型的网络结构。

什么是子模块?

在 PyTorch 中,子模块指的是一个 torch.nn.Module 实例中包含的其他 torch.nn.Module 实例。子模块是构建复杂神经网络的基础单元。例如,一个卷积神经网络可能包含多个卷积层、批量归一化层和线性层,每个层都是一个子模块

children() 方法

children() 方法返回模型的直接子模块,不包括嵌套的子模块。它只遍历模型的第一层子模块,非常适合查看模型的基本组成部分。

import torch.nn as nn

# 定义子模块
class SubModule(nn.Module):
    def __init__(self):
        super(SubModule, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        return self.fc(x)

# 定义主模块
class MainModule(nn.Module):
    def __init__(self):
        super(MainModule, self).__init__()
        self.submodule = SubModule()
        self.conv = nn.Conv2d(1, 20, 5)
        self.bn = nn.BatchNorm2d(20)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = x.view(x.size(0), -1)  # 展平
        x = self.submodule(x)
        return x

# 创建模型实例
model = MainModule()

# 打印直接子模块
print("Direct children:")
for child in model.children():
    print(child)

#输出:
Direct children:
SubModule(
  (fc): Linear(in_features=10, out_features=5, bias=True)
)
Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

从输出可以看出,children() 方法仅列出了 MainModule 的直接子模块:SubModuleConv2dBatchNorm2d

需要注意的一点是children()返回的是一个生成器,需要用for循环遍历输出,否则直接输出的是生成器的地址

modules() 方法

modules() 方法返回模型本身以及所有嵌套的子模块。它会递归地遍历模型的所有层次,非常适合查看完整的网络结构。

#将上述代码的输出部分换成:
# 打印完整的网络结构
print("Model structure:")
for module in model.modules():
    print(module)

#结果为:
Model structure:
MainModule(
  (submodule): SubModule(
    (fc): Linear(in_features=10, out_features=5, bias=True)
  )
  (conv): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
SubModule(
  (fc): Linear(in_features=10, out_features=5, bias=True)
)
Linear(in_features=10, out_features=5, bias=True)
Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

使用 modules() 方法,我们可以看到 MainModule 的完整结构,包括所有嵌套的子模块和子模块中的模块。

总结

  • children():返回模型的直接子模块,不包括嵌套的子模块。适合快速查看模型的主要组成部分。
  • modules():返回模型本身及其所有嵌套的子模块,进行递归遍历。适合查看模型的完整结构。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值