### PyTorch Transformer Forward 方法实现
在 PyTorch 中,`Transformer` 的 `forward` 方法通常涉及编码器 (`Encoder`) 和解码器 (`Decoder`) 结构的交互。以下是基于提供的引用内容以及标准实践的一个典型实现方式。
#### 编码器-解码器结构中的 `forward` 方法
在一个典型的 Transformer 实现中,`forward` 方法会接收输入张量并返回经过模型处理后的输出张量。具体来说:
1. **编码器部分**
输入序列通过嵌入层和位置编码后送入编码器堆栈。编码器的主要功能是对输入序列进行特征提取,并生成上下文表示(`enc_output`)[^3]。
2. **解码器部分**
解码器接受目标序列作为输入,并结合来自编码器的上下文向量(`enc_output`),逐步生成输出序列。在此过程中,掩码机制被广泛应用于防止解码器访问未来的信息[^4]。
3. **前向传播流程**
下面是一个简化版的 `forward` 方法实现,展示了如何将编码器输出传递给解码器,并应用掩码操作:
```python
import torch
import torch.nn as nn
class Transformer(nn.Module):
def __init__(self, encoder_layer, decoder_layer, num_encoder_layers, num_decoder_layers,
d_model, nhead, dim_feedforward, dropout=0.1):
super(Transformer, self).__init__()
self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
self.decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers)
def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None,
src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
"""
Args:
src: the sequence to the encoder (required).
tgt: the sequence to the decoder (required).
src_mask: the additive mask for the src sequence (optional).
tgt_mask: the additive mask for the tgt sequence (optional).
memory_mask: the additive mask for the encoder output (optional).
src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
Shape:
- src: :math:`(S, N, E)`
- tgt: :math:`(T, N, E)`
- Outputs: :math:`(T, N, E)`
where S is the source sequence length, T is the target sequence length,
N is the batch size, and E is the feature number.
"""
# Encoder pass
enc_output = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) # [S, N, E]
# Decoder pass using encoder's output as context/memory
dec_output = self.decoder(tgt, enc_output, tgt_mask=tgt_mask, memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask) # [T, N, E]
return dec_output
```
上述代码实现了以下逻辑:
- 使用 `nn.TransformerEncoder` 对源序列进行编码。
- 将编码器的输出传入到 `nn.TransformerDecoder` 进行解码。
- 支持多种类型的掩码参数以控制信息流[^5]。
#### 关键点解析
- **掩码作用**
掩码主要用于解决自回归建模中的因果关系问题,确保当前时刻仅依赖于过去的历史信息而不会泄露未来的数据。
- **形状说明**
- `src`: `(source_sequence_length, batch_size, embedding_dim)` 表示源序列长度、批次大小和嵌入维度。
- `tgt`: `(target_sequence_length, batch_size, embedding_dim)` 表示目标序列长度、批次大小和嵌入维度。
- 输出同样保持相同的形状 `(target_sequence_length, batch_size, embedding_dim)`。
- **可选参数**
如 `src_mask`, `tgt_mask` 等允许灵活调整模型行为,在实际应用场景下非常有用。
---
### 示例调用
假设已经准备好输入数据 `src` 和 `tgt`,可以按照如下方式进行训练:
```python
# 初始化模型和其他组件
model = Transformer(
encoder_layer=nn.TransformerEncoderLayer(d_model=512, nhead=8),
decoder_layer=nn.TransformerDecoderLayer(d_model=512, nhead=8),
num_encoder_layers=6,
num_decoder_layers=6,
d_model=512,
nhead=8,
dim_feedforward=2048
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
for epoch in range(num_epochs):
optimizer.zero_grad()
# 假设已准备好的输入数据
outputs = model(enc_inputs, dec_inputs, src_mask, tgt_mask) # 调用 forward 方法
loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
loss.backward()
optimizer.step()
```
此代码片段展示了一个完整的训练循环,其中 `model` 是之前定义的 `Transformer` 类实例化对象[^1]。
---