Scaled Dot-Product Attention(缩放点积注意力)详解
Scaled Dot-Product Attention(缩放点积注意力)是 Transformer 架构中的核心机制,由 Vaswani 等人在 2017 年《Attention Is All You Need》 论文中提出。它用于计算输入序列中不同位置之间的相关性,从而动态调整权重,使模型能够关注最重要的信息。
1. 核心公式
缩放点积注意力的计算过程如下:
Attention(Q,K,V)=softmax(QKTdk)VAttention(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V
其中:
- ( Q )(Query):查询向量,表示当前要计算的 token。
- ( K )(Key):键向量,用于计算与 Query 的相似度。
- ( V )(Value):值向量,存储实际的信息。
- ( dkd_kdk ):Key 的维度(用于缩放,防止 softmax 梯度消失)。
2. 计算步骤
(1)计算 Query 和 Key 的点积
S=QKTS = QK^TS=QKT
- 计算 Query 和 Key 的相似度(点积),得到 注意力分数矩阵 ( S )。
- 点积值越大,表示两个 token 越相关。
(2)缩放(Scaling)
Sscaled=SdkS_{\text{scaled}} = \frac{S}{\sqrt{d_k}}Sscaled=dkS
- 由于点积值可能过大,导致 softmax 梯度消失(梯度趋近于 0),因此除以 ( \sqrt{d_k} ) 进行缩放。
(3)Softmax 归一化
A=softmax(Sscaled)A = \text{softmax}(S_{\text{scaled}})A=softmax(Sscaled)
- 对缩放后的分数进行 softmax,得到 注意力权重矩阵 ( A )(概率分布,总和为 1)。
(4)加权求和 Value
Output=A⋅V\text{Output} = A \cdot VOutput=A⋅V
- 用注意力权重对 Value 进行加权求和,得到最终的输出。
直观例子
假设输入句子为 “The cat sat on the mat”:
- Query(Q):当前处理的词(如 “sat”)。
- Key(K):句子中所有词的表示(“The”, “cat”, …, “mat”)。
- Value(V):与Key对应的实际信息(如词向量)。
- 输出:通过权重计算,“sat” 可能对 “cat” 和 “mat” 分配高权重,生成包含上下文信息的表示。
3. 为什么需要缩放?
- 防止梯度消失:如果 ( d_k ) 很大(如 512 或 1024),点积值可能非常大,导致 softmax 进入饱和区(接近 0 或 1),梯度变得极小,影响训练。
- 稳定训练:缩放后,softmax 输入保持在合理范围,使梯度更稳定。
4. 代码实现(PyTorch)
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q: [batch_size, num_heads, seq_len, d_k]
K: [batch_size, num_heads, seq_len, d_k]
V: [batch_size, num_heads, seq_len, d_v]
mask: [batch_size, seq_len, seq_len] (可选)
"""
d_k = Q.size(-1) # Key 的维度
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5) # [batch, heads, seq_len, seq_len]
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9) # 掩码处理(如 Transformer 的解码器)
attn_weights = F.softmax(scores, dim=-1) # 计算注意力权重
output = torch.matmul(attn_weights, V) # 加权求和
return output, attn_weights
5. 应用场景
- Transformer 的自注意力(Self-Attention):
- 用于计算输入序列内部的关系(如 BERT、GPT)。
- 交叉注意力(Cross-Attention):
- 用于不同模态之间的交互(如文本-图像生成,Stable Diffusion 中的 U-Net 使用)。
- 多模态模型(如 CLIP、Flamingo):
- 计算文本和图像之间的相关性。
6. 与普通点积注意力的区别
对比项 | 普通点积注意力 | 缩放点积注意力 |
---|---|---|
计算方式 | (QKTQK^TQKT ) | ( QKTdk\frac{QK^T}{\sqrt{d_k}}dkQKT) |
梯度问题 | 容易梯度消失 | 稳定训练 |
适用场景 | 小维度 Key | 大维度 Key(如 Transformer) |
Transformer
7. 在Transformer中的应用
Scaled Dot-Product Attention(缩放点积注意力)是Transformer架构的核心组件:
- 基础计算单元:Transformer中的自注意力(Self-Attention)和编码器-解码器注意力(Encoder-Decoder Attention)均基于Scaled Dot-Product Attention实现
- 多头注意力(Multi-Head Attention):
Transformer通过并行多个Scaled Dot-Product Attention(即多头机制),从不同子空间捕捉信息,增强模型表达能力
。每个头的计算独立,最终拼接并线性变换 - 编码器与解码器的关键模块:
- 编码器:每层包含一个多头自注意力(Self-Attention),用于捕捉输入序列内部的依赖关系。
- 解码器:除自注意力外,还增加编码器-解码器注意力层,使解码器能动态关注编码器输出的关键信息。
8. 总结
- 作用:计算输入序列的权重分布,使模型能动态关注重要部分。
- 关键改进:缩放因子 ( \sqrt{d_k} ) 防止 softmax 饱和。
- 应用:几乎所有现代 Transformer 模型(如 GPT、BERT、Stable Diffusion)的核心组件。
如果你在实现 Transformer 或阅读相关论文时遇到它,现在应该能清晰理解它的原理了!
Transformer