单头注意力的输入是一整个序列的词向量矩阵 X,形状为 (n, d_model),其中 n 是序列长度,d_model 是向量维度。它的目标是为这个序列中的每一个词向量,都计算出一个融合了全局上下文的新版本。
现在,我们从这个精确的起点出发,一步步构建多头注意力。
第一层:准备阶段 —— 原始的 Q, K, V
在多头注意力机制开始之前,我们还是和单头一样,需要先生成总的 Query, Key, Value 矩阵。
- 输入: 词向量序列
X,形状(n, d_model)。 - 三大权重矩阵:
W_q(Query 权重),形状(d_model, d_model)W_k(Key 权重),形状(d_model, d_model)W_v(Value 权重),形状(d_model, d_model)- 注意:这里的输出维度我们先简化为
d_model,后面再细化。
- 线性投影:
Q_total = X @ W_q,形状(n, d_model)K_total = X @ W_k,形状(n, d_model)V_total = X @ W_v,形状(n, d_model)
到此为止,一切都和单头注意力完全一样。我们得到了代表整个序列的、总的 Q, K, V 矩阵。
第二层:核心操作 —— “分头” (Splitting into Heads)
这是多头与单头的第一个关键区别。我们不直接使用庞大的 Q_total, K_total, V_total,而是把它们“劈”成 h 份(h 是头的数量,比如 8)。
怎么“劈”?
这不是真的把矩阵切开,而是通过一个巧妙的维度重塑 (Reshape) 来实现。
- 设定: 假设
h = 8,d_model = 512。那么每个头的维度d_head = d_model / h = 512 / 8 = 64。 - 重塑
Q_total:- 原始形状:
(n, 512) - 步骤 a (Reshape): 将最后一个维度 512 拆成
h和d_head。
Q_reshaped形状变为(n, 8, 64)。 - 步骤 b (Transpose): 为了让所有头能并行计算,我们把
h这个维度换到前面。
Q_split形状变为(8, n, 64)。
- 原始形状:
对 K_total 和 V_total 做完全相同的操作,得到:
K_split,形状(8, n, 64)V_split,形状(8, n, 64)
底层含义:
我们现在得到了 8 组独立的、低维度的 (q, k, v) 集合。每一组都可以被看作一个独立的“专家”(一个注意力头)。第 i 组 (Q_split[i], K_split[i], V_split[i]) 的形状都是 (n, 64)。
第三层:并行计算 —— 各自独立的注意力
现在,这 8 个头并行地、独立地执行我们熟悉的 scaled dot-product attention 计算。PyTorch 等框架的强大之处在于,它能利用 GPU 的并行能力,把这 8 次计算看作一个批处理操作一次性完成。
对于第 i 个头 ( i from 0 to 7):
-
计算得分:
Scores_i = Q_split[i] @ K_split[i]^T
维度:(n, 64) @ (64, n)->(n, n) -
缩放:
Scores_i_scaled = Scores_i / sqrt(d_head)(注意,这里除的是d_head=64的平方根,而不是d_model=512) -
Softmax:
Weights_i = softmax(Scores_i_scaled, dim=-1)
维度:(n, n) -
加权求和 V:
Z_i = Weights_i @ V_split[i]
维度:(n, n) @ (n, 64)->(n, 64)
计算完成后,我们得到了 8 个输出矩阵 Z_0, Z_1, ..., Z_7,每一个的形状都是 (n, 64)。Z_i 代表了第 i 个专家(头)对序列的分析结果。
第四层:整合阶段 —— “合并”与最终投影
这是多头与单头的第二个关键区别。我们需要把 8 个专家各自的分析报告整合起来,形成一个统一的最终结论。
-
拼接 (Concatenation):
- 我们将 8 个
Z_i矩阵在最后一个维度上拼接起来。 Z_concat = Concat(Z_0, Z_1, ..., Z_7)- 维度:
(n, 64)的 8 个矩阵拼接后,形状变回(n, 512)。 - 在代码实现中,这一步通常是第二步“分头”时
reshape和transpose的逆操作。
- 我们将 8 个
-
最终线性投影:
- 仅仅拼接是不够的,因为这只是把 8 份报告简单地订在一起。我们需要一个“总编辑”来审阅和融合这些报告。
- 这个“总编辑”就是另一个可学习的权重矩阵
W_o(Output 权重),形状为(d_model, d_model),即(512, 512)。 Z_final = Z_concat @ W_o- 维度:
(n, 512) @ (512, 512)->(n, 512)。
Z_final 就是多头注意力模块的最终输出。它的形状和原始输入 X 完全一样,可以无缝地送入下一个层(比如 FFN 层或下一个 Transformer Block)。
总结从底层开始的流程:
X
↓ (Projection with Wq, Wk, Wv)
Q_total, K_total, V_total (all n x 512)
↓ (Reshape & Transpose)
Q_split, K_split, V_split (all 8 x n x 64)
↓ (Parallel Scaled Dot-Product Attention)
* for each head i in 8:
* Weights_i = softmax( (Q_split[i] @ K_split[i]^T) / sqrt(64) )
* Z_i = Weights_i @ V_split[i] ( Z_i is n x 64 )
↓ (Concatenate)
Z_concat ( n x 512 )
↓ (Final Projection with Wo)
Z_final ( n x 512 )
8454

被折叠的 条评论
为什么被折叠?



