多头注意力(Multi-Head Attention, MHA)

多头注意力(Multi-Head Attention, MHA)是 Transformer 模型的核心机制之一,它通过多个注意力头(Attention Heads)并行计算,使模型能够关注输入序列的不同部分,从而增强学习能力。

1. 计算公式

多头注意力的核心计算公式如下:

其中:

  • Q(Query):查询矩阵,表示当前词或句子想要关注的信息。
  • K(Key):键矩阵,表示所有词的特征。
  • V(Value):值矩阵,表示注意力加权后的最终特征。
  • d_{k}是每个注意力头的维度,\sqrt{d_{k}}作为缩放因子,防止梯度消失/爆炸。

多头注意力中,我们不止计算一次注意力,而是用多个不同的Q,K,V计算多个注意力头,每个注意力头学习不同的特征。

2. 多头注意力的计算过程

假设:

  • 词向量维度 d_{model}=512
  • 头数 h=8
  • 每个注意力头的维度 d_{k}=d_{model}/h=64

计算流程:

  1. 输入嵌入

    • 句子中的每个单词用 d_{model} 维的向量表示,例如: X=[x_{1},x_{2},...,x_{n}],X\in \mathbb{R}^{n\times d_{model}}
    • 其中 n 是序列长度。
  2. 线性变换

    • 用不同的参数矩阵,将输入映射到不同的子空间:Q=XW_{Q}, K=XW_{K}, V=XW_{V}
    • 其中:
      • W_{Q}, W_{K}, W_{V}\in \mathbb{R}^{d_{model}\times d_{model}}
      • 这些参数是可训练的。
  3. 拆分成多个头

    • 我们将Q, K, V按照头数拆分:Q, K, V\in \mathbb{R}^{n\times d_{model}}\rightarrow \mathbb{R}^{n\times h\times d_{k}}
    • 例如,假设d_{model}=512, h=8, d_{k}=64则:Q, K, V\in \mathbb{R}^{n\times 512}\rightarrow 8\times (n\times 64)
    • 这样,每个注意力头计算一个 d_{k}- 维的注意力。
  4. 计算每个头的注意力

    • 计算每个头的注意力权重:

    • 其中:
      • QK^{T}计算 Query 和 Key 之间的相似度(点积)。
      • softmax 归一化后得到注意力权重。
      • 乘以 V获取新的表示。
  5. 合并多头输出

    • 所有头计算完成后,拼接结果:Concat(head1,...,head h) 
    • 这样我们得到形状为 \mathbb{R}^{n\times d_{model}}的矩阵。
  6. 最终线性变换

    • 使用一个变换矩阵 W_{O} 进行线性变换:Output=Concat(head1,...,head h)W_{O} 
    • 这里W_{O}也是可训练参数,最终的输出仍然是 d_{model}维。

3. PyTorch 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 每个头的维度

        self.W_q = nn.Linear(d_model, d_model)  # 线性变换 Q
        self.W_k = nn.Linear(d_model, d_model)  # 线性变换 K
        self.W_v = nn.Linear(d_model, d_model)  # 线性变换 V
        self.W_o = nn.Linear(d_model, d_model)  # 最终变换矩阵

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        
        # 计算 Q, K, V
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # 计算注意力分数
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        
        # 应用 Mask(用于解码器防止看到未来信息)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)

        # 还原维度
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(out)

# 测试
batch_size, seq_len, d_model = 2, 5, 512
num_heads = 8

x = torch.rand(batch_size, seq_len, d_model)
attn = MultiHeadAttention(d_model, num_heads)
output = attn(x)
print(output.shape)  # 应该输出: [2, 5, 512]

4. 为什么使用多头注意力?

相比于单头注意力,多头注意力的优点:

  1. 捕捉不同层次的信息:每个头可以学习不同的注意力模式,如关注不同位置的单词。
  2. 增强模型的表达能力:不同头的学习能力互补,使 Transformer 更强大。
  3. 避免单头注意力的局限:单头注意力可能会过度关注某些特定部分,多头可以让不同部分的信息融合。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值