<think>嗯,用户让我详细介绍一下Transformer模型的代码。首先,我需要确认他们对Transformer的理解程度。可能他们已经有了一些理论基础,现在想通过代码来加深理解。我得先从模型的结构开始,分模块讲解,比如自注意力机制、多头注意力、位置编码等等。
用户可能希望看到具体的代码示例,而不仅仅是理论。所以,我需要用PyTorch这样的常见框架来展示,这样更容易上手。同时,要注意代码的注释,解释每一步的作用,比如Q、K、V的计算,缩放点积注意力,还有多头注意力的实现方式。
还要考虑用户可能对某些细节不太清楚,比如位置编码的实现,或者前馈网络的结构。这部分需要详细说明,比如使用正弦函数和余弦函数来生成位置信息,以及如何将位置编码加到输入向量上。
另外,用户可能想知道各个模块是如何组合在一起的,比如编码器和解码器的堆叠,掩码的使用等等。这里需要展示完整的Transformer类,并解释每个参数的作用,比如词嵌入维度、头数、层数等。
在写代码的时候,要注意遵循系统级指令,比如数学公式用$...$,独立公式用$$...$$。例如,在解释缩放点积注意力时,公式要用LaTeX正确表示。同时,结构要清晰,分步骤讲解,确保用户能一步步理解。
可能还需要指出一些常见的实现细节,比如为什么使用Layer Normalization,残差连接的作用,以及训练和推理时解码器的不同处理(掩码的使用)。这些点对于正确实现Transformer模型非常重要,用户可能会在这里遇到问题。
最后,要提醒用户在实际应用中的注意事项,比如参数调整、数据预处理、训练技巧等。虽然用户主要问的是代码详解,但提供一些实用建议会更有帮助。同时,要确保代码示例正确,没有语法错误,用户可以复制粘贴后进行测试。</think>我将分步骤为您详解Transformer模型的核心代码实现(基于PyTorch框架)。让我们从关键组件到完整架构逐步解析:
```python
import torch
import torch.nn as nn
import math
# ======================
# 核心组件实现
# ======================
# 1. 缩放点积注意力(Scaled Dot-Product Attention)
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
计算:$Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$
"""
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attention = torch.softmax(scores, dim=-1)
return torch.matmul(attention, V)
# 2. 多头注意力机制
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.d_model = d_model # 模型维度,如512
self.num_heads = num_heads # 头数,如8
self.d_k = d_model // num_heads
# 线性变换矩阵
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def split_heads(self, x):
"""
将输入张量重塑为:(batch_size, num_heads, seq_len, d_k)
"""
batch_size = x.size(0)
return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
def forward(self, Q, K, V, mask=None):
# 线性变换 + 分头
Q = self.split_heads(self.W_Q(Q)) # (B, h, L, d_k)
K = self.split_heads(self.W_K(K))
V = self.split_heads(self.W_V(V))
# 计算注意力
attention = scaled_dot_product_attention(Q, K, V, mask)
# 拼接多头结果
attention = attention.transpose(1, 2).contiguous() # (B, L, h, d_k)
original_size = (attention.size(0), -1, self.d_model)
output = self.W_O(attention.view(*original_size))
return output
# 3. 位置前馈网络
class PositionWiseFFN(nn.Module):
def __init__(self, d_model, d_ff=2048):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.relu = nn.ReLU()
def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))
# 4. 位置编码
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=512):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0)/d_model))
pe[:, 0::2] = torch.sin(position * div_term) # 偶数位置
pe[:, 1::2] = torch.cos(position * div_term) # 奇数位置
self.register_buffer('pe', pe.unsqueeze(0))
def forward(self, x):
return x + self.pe[:, :x.size(1)] # 加入位置编码
# ======================
# Transformer完整架构
# ======================
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = PositionWiseFFN(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力 + 残差连接
attn_output = self.self_attn(x, x, x, mask)
x = self.norm1(x + self.dropout(attn_output))
# 前馈网络 + 残差连接
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_output))
return x
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff=2048, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.cross_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = PositionWiseFFN(d_model, d_ff)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, encoder_output, src_mask, tgt_mask):
# 自注意力(解码器端)
attn_output = self.self_attn(x, x, x, tgt_mask)
x = self.norm1(x + self.dropout(attn_output))
# 交叉注意力(编码器-解码器)
attn_output = self.cross_attn(x, encoder_output, encoder_output, src_mask)
x = self.norm2(x + self.dropout(attn_output))
# 前馈网络
ffn_output = self.ffn(x)
x = self.norm3(x + self.dropout(ffn_output))
return x
class Transformer(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
num_layers=6, d_ff=2048, dropout=0.1):
super().__init__()
# 编码器部分
self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.pos_encoding = PositionalEncoding(d_model)
self.encoder_layers = nn.ModuleList(
[TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)])
self.decoder_layers = nn.ModuleList(
[TransformerDecoderLayer(d_model, num_heads, d_ff, dropout)
for _ in range(num_layers)])
self.final_linear = nn.Linear(d_model, tgt_vocab_size)
self.dropout = nn.Dropout(dropout)
def encode(self, src, src_mask):
x = self.encoder_embedding(src)
x = self.pos_encoding(x)
for layer in self.encoder_layers:
x = layer(x, src_mask)
return x
def decode(self, tgt, encoder_output, src_mask, tgt_mask):
x = self.decoder_embedding(tgt)
x = self.pos_encoding(x)
for layer in self.decoder_layers:
x = layer(x, encoder_output, src_mask, tgt_mask)
return x
def forward(self, src, tgt, src_mask, tgt_mask):
encoder_output = self.encode(src, src_mask)
decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
return self.final_linear(decoder_output)
```
关键实现细节说明:
1. **缩放点积注意力**:
- 使用$\sqrt{d_k}$进行缩放,防止softmax梯度消失
- 通过mask实现padding mask和sequence mask
2. **多头注意力**:
- 将输入拆分为多个头并行计算
- 最后通过线性层合并结果
- 计算复杂度:$O(L^2 \cdot d)$
3. **位置编码**:
- 使用正弦/余弦函数的固定编码方案
- 公式:$$PE_{(pos,2i)} = \sin(pos/10000^{2i/d_{model}})$$
$$PE_{(pos,2i+1)} = \cos(pos/10000^{2i/d_{model}})$$
4. **前馈网络**:
- 包含两个线性变换和ReLU激活
- 中间维度通常设为4倍模型维度(如512→2048)
5. **残差连接与层归一化**:
- 每个子层输出:LayerNorm(x + Sublayer(x))
- 帮助缓解梯度消失问题
使用示例:
```python
# 参数设置
d_model = 512
num_heads = 8
num_layers = 6
src_vocab = 10000
tgt_vocab = 8000
model = Transformer(src_vocab, tgt_vocab, d_model, num_heads, num_layers)
# 假设输入
src = torch.randint(0, src_vocab, (32, 100)) # batch_size=32, seq_len=100
tgt = torch.randint(0, tgt_vocab, (32, 90))
src_mask = torch.ones(32, 1, 100) # 假设没有padding
tgt_mask = torch.tril(torch.ones(90, 90)) # 解码器掩码
output = model(src, tgt, src_mask, tgt_mask)
print(output.shape) # torch.Size([32, 90, 8000])
```
实际应用注意事项:
1. 需要根据任务调整超参数(层数、头数、维度等)
2. 训练时需处理不同长度的序列(padding mask)
3. 解码时使用自回归生成(teacher forcing或beam search)
4. 建议使用学习率warmup策略
5. 对于大数据集,推荐使用混合精度训练
此实现完整展示了Transformer的核心架构,可在此基础上进行修改以适应不同任务(如BERT、GPT等变体)。