Transformer解码器终极指南:从Masked Attention到Cross-Attention的PyTorch逐行实现


Transformer 解码器深度解读 + 代码实战


1. 解码器核心作用

Transformer 解码器的核心任务是基于编码器的语义表示逐步生成目标序列(如翻译结果、文本续写)。它通过 掩码自注意力编码器-解码器交叉注意力,实现自回归生成并融合源序列信息。与编码器的核心差异:

  • 掩码机制:防止解码时看到未来信息(训练时并行,推理时逐步生成)。
  • 交叉注意力:将编码器输出作为 Key/Value,解码器当前状态作为 Query。

2. 解码器单层结构详解

每层解码器包含以下模块(附 PyTorch 代码):


2.1 掩码多头自注意力(Masked Multi-Head Self-Attention)
class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        
        # 生成 Q/K/V 的线性层
        self.to_qkv = nn.Linear(embed_size, embed_size * 3)
        self.scale = self.head_dim ** -0.5  # 缩放因子
        
        # 输出线性层
        self.to_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # 生成 Q/K/V 并分割多头
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: t.view(batch_size, seq_len, self.heads, self.head_dim), qkv)
        
        # 计算注意力分数 QK^T / sqrt(d_k)
        attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        
        # 应用下三角掩码(防止看到未来信息)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)  # 掩码位置填充极小值
        
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值