在 PyTorch 中,__init__()
和 forward()
是定义神经网络模型时最常用的两个函数。它们分别负责模型的初始化和前向传播计算。以下是它们的联系、区别以及使用方式的详细说明:
1. __init__()
函数
作用
__init__()
是类的构造函数,用于初始化模型的参数和结构。- 在这个函数中,你需要定义模型的层(如卷积层、全连接层、Dropout 等)以及它们的参数。
- 它会在创建模型实例时自动调用。
具体实现
- 在
SAGE
类中,__init__()
定义了模型的层结构:- 使用
nn.ModuleList()
存储了三个SAGEConv
层。 - 定义了一个
Dropout
层,用于在训练时随机丢弃部分神经元,防止过拟合。
- 使用
代码示例
def __init__(self, in_feats, n_hidden, n_classes):
super().__init__()
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
self.dropout = nn.Dropout(0.5)
参数说明
in_feats
:输入特征的维度。n_hidden
:隐藏层的维度。n_classes
:输出类别的数量。
2. forward()
函数
作用
forward()
定义了模型的前向传播过程,即如何从输入数据计算输出结果。- 它接收输入数据(如图结构
sg
和节点特征x
),并通过模型的层结构逐步计算输出。 - 在 PyTorch 中,
forward()
是实际执行模型计算的地方。
具体实现
- 在
SAGE
类中,forward()
实现了以下逻辑:- 将输入特征
x
赋值给h
。 - 依次通过三个
SAGEConv
层,每层之间使用ReLU
激活函数和Dropout
。 - 最后一层不使用
ReLU
和Dropout
,直接返回输出。
- 将输入特征
代码示例
def forward(self, sg, x):
h = x
for l, layer in enumerate(self.layers):
h = layer(sg, h)
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
return h
参数说明
sg
:输入的图结构(DGLGraph 对象)。x
:输入的节点特征(Tensor 对象)。
3. __init__()
和 forward()
的联系与区别
联系
__init__()
和forward()
共同定义了模型的结构和行为。__init__()
中定义的层(如SAGEConv
和Dropout
)会在forward()
中被调用,用于计算输出。
区别
特性 | __init__() | forward() |
---|---|---|
调用时机 | 在创建模型实例时调用(初始化阶段)。 | 在模型推理或训练时调用(前向传播阶段)。 |
主要任务 | 定义模型的结构和参数。 | 定义如何从输入计算输出。 |
是否显式调用 | 不需要显式调用,由 Python 自动调用。 | 不需要显式调用,由 PyTorch 自动调用。 |
参数 | 接收模型的超参数(如输入维度、隐藏层维度)。 | 接收输入数据(如图结构、节点特征)。 |
4. 如何使用 __init__()
和 forward()
步骤 1:定义模型
- 在
__init__()
中定义模型的结构。 - 在
forward()
中定义模型的计算逻辑。
步骤 2:创建模型实例
- 通过传入超参数创建模型实例。
步骤 3:调用模型
- 直接调用模型实例(如
model(sg, x)
),PyTorch 会自动调用forward()
函数。
代码示例
# 定义模型
model = SAGE(in_feats=128, n_hidden=64, n_classes=10)
# 输入数据
sg = dgl.graph(...) # 图结构
x = torch.randn(100, 128) # 节点特征
# 前向传播
output = model(sg, x) # 自动调用 forward()
print(output.shape) # 输出形状
5. 总结
__init__()
负责初始化模型的结构和参数。forward()
负责定义模型的计算逻辑。- 两者共同构成了 PyTorch 模型的核心部分。
- 使用时,只需定义好模型的结构和计算逻辑,PyTorch 会自动处理前向传播的过程。