在使用 PyTorch 构建和调试神经网络时,了解模型的结构非常重要。PyTorch 提供了 children()
和 modules()
两个方法来遍历和查看模型的子模块。本文将详细介绍这两个方法的区别,并演示如何使用它们来输出模型的网络结构。
什么是子模块?
在 PyTorch 中,子模块指的是一个 torch.nn.Module
实例中包含的其他 torch.nn.Module
实例。子模块是构建复杂神经网络的基础单元。例如,一个卷积神经网络可能包含多个卷积层、批量归一化层和线性层,每个层都是一个子模块。
children()
方法
children()
方法返回模型的直接子模块,不包括嵌套的子模块。它只遍历模型的第一层子模块,非常适合查看模型的基本组成部分。
import torch.nn as nn
# 定义子模块
class SubModule(nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.fc = nn.Linear(10, 5)
def forward(self, x):
return self.fc(x)
# 定义主模块
class MainModule(nn.Module):
def __init__(self):
super(MainModule, self).__init__()
self.submodule = SubModule()
self.conv = nn.Conv2d(1, 20, 5)
self.bn = nn.BatchNorm2d(20)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = x.view(x.size(0), -1) # 展平
x = self.submodule(x)
return x
# 创建模型实例
model = MainModule()
# 打印直接子模块
print("Direct children:")
for child in model.children():
print(child)
#输出:
Direct children:
SubModule(
(fc): Linear(in_features=10, out_features=5, bias=True)
)
Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
从输出可以看出,children()
方法仅列出了 MainModule
的直接子模块:SubModule
、Conv2d
和 BatchNorm2d
。
需要注意的一点是children()返回的是一个生成器,需要用for循环遍历输出,否则直接输出的是生成器的地址
modules()
方法
modules()
方法返回模型本身以及所有嵌套的子模块。它会递归地遍历模型的所有层次,非常适合查看完整的网络结构。
#将上述代码的输出部分换成:
# 打印完整的网络结构
print("Model structure:")
for module in model.modules():
print(module)
#结果为:
Model structure:
MainModule(
(submodule): SubModule(
(fc): Linear(in_features=10, out_features=5, bias=True)
)
(conv): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(bn): BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
SubModule(
(fc): Linear(in_features=10, out_features=5, bias=True)
)
Linear(in_features=10, out_features=5, bias=True)
Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
BatchNorm2d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
使用 modules()
方法,我们可以看到 MainModule
的完整结构,包括所有嵌套的子模块和子模块中的模块。
总结
children()
:返回模型的直接子模块,不包括嵌套的子模块。适合快速查看模型的主要组成部分。modules()
:返回模型本身及其所有嵌套的子模块,进行递归遍历。适合查看模型的完整结构。