SDPA(Scaled Dot-Product Attention)详解
SDPA(Scaled Dot-Product Attention,缩放点积注意力)是 Transformer 模型的核心计算单元,最早由 Vaswani 等人在 2017 年的论文《Attention Is All You Need》提出。它通过计算查询(Query)、键(Key)和值(Value)之间的相似度,生成上下文感知的表示。
1. SDPA 的数学定义
给定:
- 查询矩阵(Query):Q∈Rn×dkQ \in \mathbb{R}^{n \times d_k}Q∈Rn×dk
- 键矩阵(Key):K∈Rm×dkK \in \mathbb{R}^{m \times d_k}K∈Rm×dk
- 值矩阵(Value):V∈Rm×dvV \in \mathbb{R}^{m \times d_v}V∈Rm×dv
SDPA 的计算公式为:
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) VAttention(Q,K,V)=softmax(dkQKT)V
其中:
- QKTQK^TQKT 计算查询和键的点积(相似度)。
- dk\sqrt{d_k}dk 用于缩放点积,防止梯度消失或爆炸(尤其是 dkd_kdk 较大时)。
- softmax 将注意力权重归一化为概率分布。
- 最终加权求和 VVV 得到输出。
2. SDPA 的计算步骤
- 计算相似度(Dot-Product)
- 计算 QQQ 和 KKK 的点积:
S=QKTS = QK^TS=QKT - 相似度矩阵 S∈Rn×mS \in \mathbb{R}^{n \times m}S∈Rn×m 表示每个查询对所有键的匹配程度。
-
缩放(Scaling)
- 除以 dk\sqrt{d_k}dk(键向量的维度),防止点积值过大导致 softmax 梯度消失:
Sscaled=SdkS_{\text{scaled}} = \frac{S}{\sqrt{d_k}}Sscaled=dkS
- 除以 dk\sqrt{d_k}dk(键向量的维度),防止点积值过大导致 softmax 梯度消失:
-
Softmax 归一化
- 对每行(每个查询)做 softmax,得到注意力权重 AAA:
A=softmax(Sscaled)A = \text{softmax}(S_{\text{scaled}})A=softmax(Sscaled) - 保证 ∑jAi,j=1\sum_j A_{i,j} = 1∑jAi,j=1,权重总和为 1。
- 对每行(每个查询)做 softmax,得到注意力权重 AAA:
-
加权求和(Value 聚合)
- 用注意力权重 AAA 对 VVV 加权求和,得到最终输出:
Output=A⋅V\text{Output} = A \cdot VOutput=A⋅V - 输出维度:Rn×dv\mathbb{R}^{n \times d_v}Rn×dv。
- 用注意力权重 AAA 对 VVV 加权求和,得到最终输出:
3. SDPA 的作用与优势
✅ 核心作用:
- 让模型动态关注输入的不同部分(类似人类注意力机制)。
- 适用于序列数据(如文本、语音、视频),捕捉长距离依赖。
✅ 优势:
- 并行计算友好
- 矩阵乘法(GEMM)可高效并行加速(GPU/TPU 优化)。
- 可解释性
- 注意力权重可视化(如
BertViz
)可分析模型关注哪些 token。
- 注意力权重可视化(如
- 灵活扩展
- 可结合 多头注意力(Multi-Head Attention) 增强表达能力。
4. SDPA 的变体与优化
变体/优化 | 核心改进 | 应用场景 |
---|---|---|
多头注意力(MHA) | 并行多个 SDPA,增强特征多样性 | Transformer (BERT, GPT) |
FlashAttention | 优化内存访问,减少 HBM 读写 | 长序列推理(如 8K+ tokens) |
Sparse Attention | 只计算局部或稀疏的注意力 | 降低计算复杂度(如 Longformer) |
Linear Attention | 用线性近似替代 softmax | 低资源设备(如 RetNet) |
5. 代码实现(PyTorch 示例)
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
return output
# 示例输入
Q = torch.randn(2, 5, 64) # (batch_size, seq_len, d_k)
K = torch.randn(2, 5, 64)
V = torch.randn(2, 5, 128)
output = scaled_dot_product_attention(Q, K, V)
print(output.shape) # torch.Size([2, 5, 128])
6. 总结
- SDPA 是 Transformer 的基石,通过 Query-Key-Value 机制 + Softmax 归一化 实现动态注意力。
- 关键优化点:缩放(防止梯度问题)、并行计算、内存效率(如 FlashAttention)。
- 现代优化(如 SageAttention2)进一步结合 量化、稀疏化、离群值处理 提升效率。
SDPA 及其变体已成为 NLP、CV、多模态领域的核心组件,理解其原理对模型优化至关重要。
SDPA计算过程举例
我们通过一个具体的数值例子,逐步演示 SDPA 的计算过程。假设输入如下(简化版,便于手动计算):
输入数据(假设 d_k = 2
, d_v = 3
)
-
Query (Q):2 个查询(
n=2
),每个查询维度d_k=2
Q=[1234]Q = \begin{bmatrix} 1 & 2 \\ 3 & 4 \\ \end{bmatrix}Q=[1324] -
Key (K):3 个键(
m=3
),每个键维度d_k=2
K=[5678910]K = \begin{bmatrix} 5 & 6 \\ 7 & 8 \\ 9 & 10 \\ \end{bmatrix}K=5796810 -
Value (V):3 个值(
m=3
),每个值维度d_v=3
V=[101010110]V = \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & 0 \\ 1 & 1 & 0 \\ \end{bmatrix}V=101011100
Step 1: 计算 Query 和 Key 的点积(Dot-Product)
计算 S=QKTS = QK^TS=QKT:
QKT=[1⋅5+2⋅61⋅7+2⋅81⋅9+2⋅103⋅5+4⋅63⋅7+4⋅83⋅9+4⋅10]=[5+127+169+2015+2421+3227+40]=[172329395367]QK^T = \begin{bmatrix} 1 \cdot 5 + 2 \cdot 6 & 1 \cdot 7 + 2 \cdot 8 & 1 \cdot 9 + 2 \cdot 10 \\ 3 \cdot 5 + 4 \cdot 6 & 3 \cdot 7 + 4 \cdot 8 & 3 \cdot 9 + 4 \cdot 10 \\ \end{bmatrix} = \begin{bmatrix} 5+12 & 7+16 & 9+20 \\ 15+24 & 21+32 & 27+40 \\ \end{bmatrix} = \begin{bmatrix} 17 & 23 & 29 \\ 39 & 53 & 67 \\ \end{bmatrix}QKT=[1⋅5+2⋅63⋅5+4⋅61⋅7+2⋅83⋅7+4⋅81⋅9+2⋅103⋅9+4⋅10]=[5+1215+247+1621+329+2027+40]=[173923532967]
Step 2: 缩放(Scaling)
除以 dk=2≈1.414\sqrt{d_k} = \sqrt{2} \approx 1.414dk=2≈1.414:
Sscaled=S2=[17/1.41423/1.41429/1.41439/1.41453/1.41467/1.414]≈[12.0216.2620.5127.5837.4847.38]S_{\text{scaled}} = \frac{S}{\sqrt{2}} = \begin{bmatrix} 17/1.414 & 23/1.414 & 29/1.414 \\ 39/1.414 & 53/1.414 & 67/1.414 \\ \end{bmatrix} \approx \begin{bmatrix} 12.02 & 16.26 & 20.51 \\ 27.58 & 37.48 & 47.38 \\ \end{bmatrix}Sscaled=2S=[17/1.41439/1.41423/1.41453/1.41429/1.41467/1.414]≈[12.0227.5816.2637.4820.5147.38]
Step 3: Softmax 归一化(计算注意力权重)
对每一行(每个 Query)做 softmax:
softmax([12.02,16.26,20.51])≈[2.06×10−4,0.016,0.984]\text{softmax}([12.02, 16.26, 20.51]) \approx [2.06 \times 10^{-4}, 0.016, 0.984]softmax([12.02,16.26,20.51])≈[2.06×10−4,0.016,0.984]
softmax([27.58,37.48,47.38])≈[1.67×10−9,0.0001,0.9999]\text{softmax}([27.58, 37.48, 47.38]) \approx [1.67 \times 10^{-9}, 0.0001, 0.9999]softmax([27.58,37.48,47.38])≈[1.67×10−9,0.0001,0.9999]
因此,注意力权重矩阵 AAA 为:
A≈[2.06×10−40.0160.9841.67×10−90.00010.9999]A \approx \begin{bmatrix} 2.06 \times 10^{-4} & 0.016 & 0.984 \\ 1.67 \times 10^{-9} & 0.0001 & 0.9999 \\ \end{bmatrix}A≈[2.06×10−41.67×10−90.0160.00010.9840.9999]
解释:
- 第 1 个 Query 主要关注第 3 个 Key(权重 0.984)。
- 第 2 个 Query 几乎只关注第 3 个 Key(权重 0.9999)。
Step 4: 加权求和(聚合 Value)
计算 Output=A⋅V\text{Output} = A \cdot VOutput=A⋅V:
Output=[2.06×10−4⋅1+0.016⋅0+0.984⋅12.06×10−4⋅0+0.016⋅1+0.984⋅12.06×10−4⋅1+0.016⋅0+0.984⋅0]T≈[0.9841.0000.0002]T\text{Output} = \begin{bmatrix} 2.06 \times 10^{-4} \cdot 1 + 0.016 \cdot 0 + 0.984 \cdot 1 \\ 2.06 \times 10^{-4} \cdot 0 + 0.016 \cdot 1 + 0.984 \cdot 1 \\ 2.06 \times 10^{-4} \cdot 1 + 0.016 \cdot 0 + 0.984 \cdot 0 \\ \end{bmatrix}^T \approx \begin{bmatrix} 0.984 \\ 1.000 \\ 0.0002 \\ \end{bmatrix}^TOutput=2.06×10−4⋅1+0.016⋅0+0.984⋅12.06×10−4⋅0+0.016⋅1+0.984⋅12.06×10−4⋅1+0.016⋅0+0.984⋅0T≈0.9841.0000.0002T
Output=[0.9841.0000.00020.99990.99990.0001]\text{Output} = \begin{bmatrix} 0.984 & 1.000 & 0.0002 \\ 0.9999 & 0.9999 & 0.0001 \\ \end{bmatrix}Output=[0.9840.99991.0000.99990.00020.0001]
解释:
- 第 1 行:主要聚合了第 3 个 Value
[1, 1, 0]
,但受前两个 Value 微弱影响。- 第 2 行:几乎完全由第 3 个 Value 决定。
最终输出
Output≈[0.9841.0000.00020.99990.99990.0001]\text{Output} \approx \begin{bmatrix} 0.984 & 1.000 & 0.0002 \\ 0.9999 & 0.9999 & 0.0001 \\ \end{bmatrix}Output≈[0.9840.99991.0000.99990.00020.0001]
总结
- 点积:计算 Query 和 Key 的相似度。
- 缩放:防止梯度爆炸/消失。
- Softmax:归一化为概率分布。
- 加权求和:聚合 Value 得到最终表示。
这个例子展示了 SDPA 如何动态分配注意力权重,并生成上下文感知的输出。实际应用中(如 Transformer),还会结合 多头注意力(Multi-Head Attention) 增强表达能力。