多头注意力(Multi-Head Attention, MHA)是 Transformer 模型的核心机制之一,它通过多个注意力头(Attention Heads)并行计算,使模型能够关注输入序列的不同部分,从而增强学习能力。
1. 计算公式
多头注意力的核心计算公式如下:
其中:
(Query):查询矩阵,表示当前词或句子想要关注的信息。
(Key):键矩阵,表示所有词的特征。
(Value):值矩阵,表示注意力加权后的最终特征。
是每个注意力头的维度,
作为缩放因子,防止梯度消失/爆炸。
在多头注意力中,我们不止计算一次注意力,而是用多个不同的计算多个注意力头,每个注意力头学习不同的特征。
2. 多头注意力的计算过程
假设:
- 词向量维度
- 头数
- 每个注意力头的维度
计算流程:
-
输入嵌入
- 句子中的每个单词用
维的向量表示,例如:
- 其中 n 是序列长度。
- 句子中的每个单词用
-
线性变换
- 用不同的参数矩阵,将输入映射到不同的子空间:
- 其中:
- 这些参数是可训练的。
- 用不同的参数矩阵,将输入映射到不同的子空间:
-
拆分成多个头
- 我们将
按照头数拆分:
- 例如,假设
则:
- 这样,每个注意力头计算一个
维的注意力。
- 我们将
-
计算每个头的注意力
- 计算每个头的注意力权重:
- 其中:
计算 Query 和 Key 之间的相似度(点积)。
- softmax 归一化后得到注意力权重。
- 乘以
获取新的表示。
- 计算每个头的注意力权重:
-
合并多头输出
- 所有头计算完成后,拼接结果:
- 这样我们得到形状为
的矩阵。
- 所有头计算完成后,拼接结果:
-
最终线性变换
- 使用一个变换矩阵
进行线性变换:
- 这里
也是可训练参数,最终的输出仍然是
维。
- 使用一个变换矩阵
3. PyTorch 代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 每个头的维度
self.W_q = nn.Linear(d_model, d_model) # 线性变换 Q
self.W_k = nn.Linear(d_model, d_model) # 线性变换 K
self.W_v = nn.Linear(d_model, d_model) # 线性变换 V
self.W_o = nn.Linear(d_model, d_model) # 最终变换矩阵
def forward(self, x, mask=None):
batch_size, seq_len, _ = x.shape
# 计算 Q, K, V
Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
# 应用 Mask(用于解码器防止看到未来信息)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, V)
# 还原维度
out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
return self.W_o(out)
# 测试
batch_size, seq_len, d_model = 2, 5, 512
num_heads = 8
x = torch.rand(batch_size, seq_len, d_model)
attn = MultiHeadAttention(d_model, num_heads)
output = attn(x)
print(output.shape) # 应该输出: [2, 5, 512]
4. 为什么使用多头注意力?
相比于单头注意力,多头注意力的优点:
- 捕捉不同层次的信息:每个头可以学习不同的注意力模式,如关注不同位置的单词。
- 增强模型的表达能力:不同头的学习能力互补,使 Transformer 更强大。
- 避免单头注意力的局限:单头注意力可能会过度关注某些特定部分,多头可以让不同部分的信息融合。