在PyTorch中,def __init__() 和 def forward()函数的联系与区别,以及怎么使用的

在 PyTorch 中,__init__()forward() 是定义神经网络模型时最常用的两个函数。它们分别负责模型的初始化和前向传播计算。以下是它们的联系、区别以及使用方式的详细说明:


1. __init__() 函数

作用
  • __init__() 是类的构造函数,用于初始化模型的参数和结构。
  • 在这个函数中,你需要定义模型的层(如卷积层、全连接层、Dropout 等)以及它们的参数。
  • 它会在创建模型实例时自动调用。
具体实现
  • SAGE 类中,__init__() 定义了模型的层结构:
    • 使用 nn.ModuleList() 存储了三个 SAGEConv 层。
    • 定义了一个 Dropout 层,用于在训练时随机丢弃部分神经元,防止过拟合。
代码示例
def __init__(self, in_feats, n_hidden, n_classes):
    super().__init__()
    self.layers = nn.ModuleList()
    self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, "mean"))
    self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, "mean"))
    self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, "mean"))
    self.dropout = nn.Dropout(0.5)
参数说明
  • in_feats:输入特征的维度。
  • n_hidden:隐藏层的维度。
  • n_classes:输出类别的数量。

2. forward() 函数

作用
  • forward() 定义了模型的前向传播过程,即如何从输入数据计算输出结果。
  • 它接收输入数据(如图结构 sg 和节点特征 x),并通过模型的层结构逐步计算输出。
  • 在 PyTorch 中,forward() 是实际执行模型计算的地方。
具体实现
  • SAGE 类中,forward() 实现了以下逻辑:
    1. 将输入特征 x 赋值给 h
    2. 依次通过三个 SAGEConv 层,每层之间使用 ReLU 激活函数和 Dropout
    3. 最后一层不使用 ReLUDropout,直接返回输出。
代码示例
def forward(self, sg, x):
    h = x
    for l, layer in enumerate(self.layers):
        h = layer(sg, h)
        if l != len(self.layers) - 1:
            h = F.relu(h)
            h = self.dropout(h)
    return h
参数说明
  • sg:输入的图结构(DGLGraph 对象)。
  • x:输入的节点特征(Tensor 对象)。

3. __init__()forward() 的联系与区别

联系
  • __init__()forward() 共同定义了模型的结构和行为。
  • __init__() 中定义的层(如 SAGEConvDropout)会在 forward() 中被调用,用于计算输出。
区别
特性__init__()forward()
调用时机在创建模型实例时调用(初始化阶段)。在模型推理或训练时调用(前向传播阶段)。
主要任务定义模型的结构和参数。定义如何从输入计算输出。
是否显式调用不需要显式调用,由 Python 自动调用。不需要显式调用,由 PyTorch 自动调用。
参数接收模型的超参数(如输入维度、隐藏层维度)。接收输入数据(如图结构、节点特征)。

4. 如何使用 __init__()forward()

步骤 1:定义模型
  • __init__() 中定义模型的结构。
  • forward() 中定义模型的计算逻辑。
步骤 2:创建模型实例
  • 通过传入超参数创建模型实例。
步骤 3:调用模型
  • 直接调用模型实例(如 model(sg, x)),PyTorch 会自动调用 forward() 函数。
代码示例
# 定义模型
model = SAGE(in_feats=128, n_hidden=64, n_classes=10)

# 输入数据
sg = dgl.graph(...)  # 图结构
x = torch.randn(100, 128)  # 节点特征

# 前向传播
output = model(sg, x)  # 自动调用 forward()
print(output.shape)  # 输出形状

5. 总结

  • __init__() 负责初始化模型的结构和参数。
  • forward() 负责定义模型的计算逻辑。
  • 两者共同构成了 PyTorch 模型的核心部分。
  • 使用时,只需定义好模型的结构和计算逻辑,PyTorch 会自动处理前向传播的过程。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值