多头注意力(Multi-Head Attention, MHA)
多头注意力(Multi-Head Attention, MHA) 是 Transformer 模型(如 BERT 和 GPT)中的核心机制,它扩展了缩放点积注意力(Scaled Dot-Product Attention),使模型可以从多个不同的角度关注输入序列的不同部分。这种机制提高了模型的表达能力,使其能够同时捕获不同的语义信息。
1. 为什么需要多头注意力?
在标准的缩放点积注意力(Scaled Dot-Product Attention)中,查询(Query)、键(Key)和值(Value)之间的关系是通过单一的注意力机制计算的。但这种单一的注意力机制存在一定的局限性:
- 它只能关注输入序列中的一种信息模式(例如,长距离依赖)。
- 在实际任务(如机器翻译、文本理解)中,不同的单词可能会以不同的方式关注上下文。
为了解决这个问题,多头注意力通过 多个不同的注意力头(heads) 来分别学习不同的注意力权重,并最终将它们组合在一起,从而增强模型的表示能力。
2. 多头注意力的计算过程
多头注意力的核心思想是:
- 对输入进行线性变换,将输入序列映射到多个低维查询(Q)、键(K)和值(V) 空间。
- 在多个注意力头(Heads)上计算注意力,每个头使用缩放点积注意力(Scaled Dot-Product Attention)。
- 将多个头的输出拼接(Concat)并投影回输出维度,以形成最终的多头注意力输出。
数学公式
假设输入张量的维度为 d model d_{\text{model}} dmodel,多头注意力的计算步骤如下:
-
输入线性变换:
Q i = X W i Q , K i = X W i K , V i = X W i V Q_i = X W_i^Q, \quad K_i = X W_i^K, \quad V_i = X W_i^V Qi=XWiQ,Ki=XWiK,Vi=XWiV
其中:-
X
X
X 是输入序列(形状:
(batch_size, seq_len, d_model)
)。 - W i Q , W i K , W i V W_i^Q, W_i^K, W_i^V WiQ,WiK,WiV 是不同的线性变换矩阵,用于生成查询、键和值(形状: ( d model , d head ) (d_{\text{model}}, d_{\text{head}}) (dmodel,dhead))。
- d head = d model / h d_{\text{head}} = d_{\text{model}} / h dhead=dmodel/h(其中 h h h 是注意力头的数量)。
-
X
X
X 是输入序列(形状:
-
计算缩放点积注意力(Scaled Dot-Product Attention):
Attention ( Q i , K i , V i ) = softmax ( Q i K i T d head ) V i \text{Attention}(Q_i, K_i, V_i) = \text{softmax} \left(\frac{Q_i K_i^T}{\sqrt{d_{\text{head}}}} \right) V_i Attention(Qi,Ki,Vi)=softmax(dheadQiKiT)Vi
这个步骤与标准的缩放点积注意力相同,只是每个注意力头都独立计算自己的注意力分数。 -
拼接(Concat)多个头的输出:
MultiHead ( Q , K , V ) = Concat ( head 1 , head 2 , … , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat} (\text{head}_1, \text{head}_2, \dots, \text{head}_h) W^O MultiHead(Q,K,V)=Concat(head1,head2,…,headh)WO
其中:Concat
操作将所有头的输出沿最后一个维度拼接(形状变回(batch_size, seq_len, d_model)
)。- W O W^O WO 是投影矩阵,用于将拼接后的张量映射回原始维度。
3. PyTorch 代码实现
PyTorch 提供了 torch.nn.MultiheadAttention
模块,可以直接使用多头注意力。但我们先来看手动实现的多头注意力:
手动实现多头注意力
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除!"
self.d_model = d_model
self.num_heads = num_heads
self.d_head = d_model // num_heads # 每个头的维度
# 线性变换矩阵
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
def scaled_dot_product_attention(self, Q, K, V):
"""
计算缩放点积注意力
"""
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_head, dtype=torch.float32))
attention_weights = F.softmax(scores, dim=-1)
return torch.matmul(attention_weights, V), attention_weights
def forward(self, Q, K, V):
batch_size = Q.shape[0]
# 线性变换
Q = self.W_q(Q) # (batch_size, seq_len, d_model)
K = self.W_k(K)
V = self.W_v(V)
# 拆分成多个头 (batch_size, seq_len, num_heads, d_head)
Q = Q.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
K = K.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
V = V.view(batch_size, -1, self.num_heads, self.d_head).transpose(1, 2)
# 计算注意力
attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V)
# 重新组合多头 (batch_size, seq_len, d_model)
attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
# 最终的线性变换
return self.W_o(attention_output), attention_weights
# 测试模型
batch_size = 2
seq_len = 5
d_model = 16
num_heads = 4
multihead_attention = MultiHeadAttention(d_model, num_heads)
X = torch.randn(batch_size, seq_len, d_model)
output, attn_weights = multihead_attention(X, X, X)
print("Output shape:", output.shape)
print("Attention Weights shape:", attn_weights.shape)
4. PyTorch 官方 torch.nn.MultiheadAttention
在实际项目中,可以使用 torch.nn.MultiheadAttention
来实现多头注意力:
import torch
import torch.nn as nn
# 定义多头注意力
d_model = 16
num_heads = 4
multihead_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads, batch_first=True)
# 创建随机输入 (batch_size, seq_len, d_model)
X = torch.randn(2, 5, 16)
# 计算注意力
output, attn_weights = multihead_attn(X, X, X)
print("Output shape:", output.shape)
print("Attention Weights shape:", attn_weights.shape)
5. 总结
- 多头注意力(Multi-Head Attention) 通过多个注意力头捕获不同的语义信息,使得模型可以同时关注输入序列的不同部分。
- 计算过程:
- 线性变换得到多个查询(Q)、键(K)、值(V)。
- 计算多个缩放点积注意力(Scaled Dot-Product Attention)。
- 拼接所有注意力头的输出,并通过一个线性变换映射回原始维度。
- PyTorch 提供了官方实现,可以直接使用
torch.nn.MultiheadAttention
进行调用。
多头注意力是 Transformer 的核心部分,使其能够高效建模长序列数据,是当前 NLP 和 CV 任务的关键技术之一。