Transformer 解码器深度解读 + 代码实战
1. 解码器核心作用
Transformer 解码器的核心任务是基于编码器的语义表示逐步生成目标序列(如翻译结果、文本续写)。它通过 掩码自注意力 和 编码器-解码器交叉注意力,实现自回归生成并融合源序列信息。与编码器的核心差异:
- 掩码机制:防止解码时看到未来信息(训练时并行,推理时逐步生成)。
- 交叉注意力:将编码器输出作为 Key/Value,解码器当前状态作为 Query。
2. 解码器单层结构详解
每层解码器包含以下模块(附 PyTorch 代码):
2.1 掩码多头自注意力(Masked Multi-Head Self-Attention)
class MaskedMultiHeadAttention(nn.Module):
def __init__(self, embed_size, heads):
super().__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
# 生成 Q/K/V 的线性层
self.to_qkv = nn.Linear(embed_size, embed_size * 3)
self.scale = self.head_dim ** -0.5 # 缩放因子
# 输出线性层
self.to_out = nn.Linear(embed_size, embed_size)
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# 生成 Q/K/V 并分割多头
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim), qkv)
# 计算注意力分数 QK^T / sqrt(d_k)
attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
# 应用下三角掩码(防止看到未来信息)
if mask is not None:
attn = attn.masked_fill(mask == 0, -1e10) # 掩码位置填充极小值