同一句话里,每个词都会去问“我该关注谁?”——但不同专家(不同“头”)关注点不一样:有人看语法,有人看短语,有人盯实体名词。
一分钟直觉
• 三张卡片:对每个词做三份向量
• Q(Query):我想找什么
• K(Key):我是什么特征
• V(Value):我携带的信息
• 打分→归一化→加权求和:
1. 用 Q 和别人的 K 做相似度打分;
2. softmax 变成“关注权重”(总和=1);
3. 用这些权重把别人的 V 加权求和,得到这个词的新表示。
• 多头 = 多位专家:把总维度分成几份(每份叫一个“头”),各头独立做一遍上面流程。最后把各头结果拼起来,就是“多头注意力”。
一个小例子(不写公式也能懂)
句子:“她 喜欢 吃 苹果”
• 头A(语法专家)
• “喜欢”多看主语“她”和宾语“吃/苹果”
• 头B(短语专家)
• 更关注“吃↔苹果”这个搭配
同一句话,两位专家关注到的关系不一样但互补,拼在一起表示力更强。
为啥要多头
• 多视角:有人看长距离依赖,有人看邻近短语,有人看实体/位置;
• 并行好算:把维度平均分给各头,总计算量大致不变,但学到的模式更丰富;
• 信息通路短:任意两个词一层就能互相“看见”。
自注意力 vs 交叉注意力
• 自注意力:Q=K=V 都来自同一序列(常见于编码器/解码器自回归)。
• 交叉注意力:Q 来自目标序列,K/V 来自源序列(翻译里连接编码器和解码器)。
新手容易踩的点
• 维度要对上:d_model = num_heads * d_head(总维度等于头数×每头维度)。
• mask 别忘:自回归或padding位置要屏蔽(不然会“偷看”未来或看无效位)。
• 头不是越多越好:太多会让每头过窄、学不到东西;常见搭配:d_model=512 -> heads=8。
• 推理内存:KV 缓存随头数增长;服务端常用 GQA/MQA 来减小内存占用。
10 行感受一下(PyTorch)
import torch
from torch.nn import MultiheadAttention
L, N, d_model, h = 5, 1, 64, 8 # 序列长5, batch 1, 维度64, 8头
x = torch.randn(L, N, d_model) # 形状要求: [seq_len, batch, d_model]
mha = MultiheadAttention(embed_dim=d_model, num_heads=h, batch_first=False)
y, attn = mha(x, x, x, need_weights=True) # 自注意力
print(y.shape) # -> [L, N, d_model]
print(attn.shape) # -> [N*h, L, L] 每个头一张"谁关注谁"的热力图
小练习:把 num_heads 改成 1/4/8,对比 attn 的差异;或把句子换成“吃 红色 的 苹果”,看看不同头会不会更关注“红色↔苹果”。
口袋速记
• Q 找人,K 说明书,V 真信息;
• 打分-软化-加权 得到新表示;
• 多头=多专家并行,最后拼接。
多头注意力机制详解
1027

被折叠的 条评论
为什么被折叠?



