pytorch Module中的forward使用for循环与不使用for循环的区别

本文通过实例探讨了PyTorch中Module的__init__()和forward()函数的区别。作者发现,__init__()定义了Module的网络结构,而forward()决定了网络的执行流程。在forward()的for循环里,输入数据是通过同一个网络进行处理,而非每次都使用新的网络。这有助于初学者理解PyTorch模块的内部工作原理。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

作为初学者,看代码一度很迷惑,module中的forward函数中for循环,输入的Tensor数据是在同一个网络循环,还是依次向前推进了多个不同的网络。于是,我经过了下面的测试。

import torch.nn as nn


class Model1(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()

        self.block = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        x = self.block(x) + x
        x = self.block(x) + x
        x = self.block(x) + x
        return x


class Model2(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()

        self.block = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        for i in range(3):
            x = self.block(x) + x
        return x


model1 = Model1(hidden_dim=10)
model2=Model2(hidden_dim=10)
print(model1)
print(model2)

得到以下结果:

Model1(
    (block): Linear(in_feature=10, out_feature=10, bias=True)
)
Model2(
    (block): Linear(in_feature=10, out_feature=10, bias=True)
)

然后我就悟了!一个Module的结构到底是由什么构成的,是__init__()还是forward()?结论是__init__()决定了Module有哪些网络,forward()决定了Module的网络是如何连接的。在forward()中无论如何调用__init__()中定义的某个网络,始终都是同一个网络。

那么文章开头那个问题的答案就有了,答案是:for循环中,通过的是同一个网络。

SRN (Simple Recurrent Network) 简单循环神经网络是一种特殊的递归神经网络,它通过时间步骤处理序列数据,常用于自然语言处理和音频信号处理等场景。在 PyTorch 中实现 SRN 的核心算子包括: 1. **Linear Layer (全连接层)**:这是基本的线性变换,`nn.Linear()` 类在 PyTorch 中实现,它会计算输入和权重矩阵的乘积,并加上偏置项。 2. **Activation Function**:例如 `nn.ReLU()` 或 `nn.Tanh()`,它们用于引入非线性,使模型能够学习更复杂的函数映射。 3. **LSTM Cell (长短期记忆单元)**:如果你要用到 LSTM 这种更复杂的循环结构,可以使用 `nn.LSTMCell()`。LSTM 有细胞状态(cell state)和隐藏状态(hidden state),包含 Forget Gate、Input Gate、Output Gate 和 Update Gate 等门控机制。 4. **RNN Module (如 nn.RNN or nn.GRU)**:PyTorch 提供了 `nn.RNN` 和 `nn.GRU` 模块,用于封装上述操作并自动处理梯度的计算,简化了循环网络的设计。 5. **TimeDistributed Wrapper**:对于需要跨时间步应用的一层,可以使用 `nn.TimeDistributed` 来包裹,使其可以接收到整个序列作为输入。 6. **Sequence Packing and Unpacking**:由于 RNN 需要顺序处理数据,可能会遇到长度一致的序列。这时可以使用 `nn.utils.rnn.pack_padded_sequence` 和 `nn.utils.rnn.pad_packed_sequence` 进行打包和拆包。 实现一个简单的 SRN 可能会涉及到定义网络架构、初始化参数、前向传播以及反向传播。以下是简化的示例代码片段: ```python import torch.nn as nn class SimpleRNN(nn.Module): def __init__(self, input_size, hidden_size): super(SimpleRNN, self).__init__() self.hidden_size = hidden_size self.rnn = nn.RNN(input_size, hidden_size) def forward(self, inputs, h0=None): output, _ = self.rnn(inputs, h0) return output # 使用例子 input_seq = torch.randn(10, 32, 64) # 10个时间步,每个时间步32维特征 model = SimpleRNN(64, 128) h0 = torch.zeros(1, input_seq.size(0), model.hidden_size) # 初始化隐状态 output = model(input_seq, h0) ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值