torch.nn.Module
是 PyTorch 中一个重要的基类,用于构建神经网络模型。它提供了一种方便的方式来组织和管理模型参数、定义前向传播等功能。继承自 torch.nn.Module
的类可以被视为一个可训练的参数集合,可以包含其他模块,从而形成层次化的模型结构。
文章目录
一、关键功能和属性
1.1 参数管理
torch.nn.Module
可以追踪并管理所有注册的参数。通过 parameters()
方法,可以方便地获取模型中的所有参数。
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 5)
model = MyModel()
print(list(model.parameters()))
1.2 子模块管理
通过将其他 torch.nn.Module
的实例注册为当前模块的属性,可以形成层次化的模型结构。这使得模型可以以更模块化的方式进行定义。
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
model = MyModel()
1.3 前向传播定义
在继承 torch.nn.Module
的子类中,需要实现 forward
方法来定义模型的前向传播过程。