SDPA:Scaled Dot-Product Attention(缩放点积注意力)

Scaled Dot-Product Attention(缩放点积注意力)详解

Scaled Dot-Product Attention(缩放点积注意力)是 Transformer 架构中的核心机制,由 Vaswani 等人在 2017 年《Attention Is All You Need》 论文中提出。它用于计算输入序列中不同位置之间的相关性,从而动态调整权重,使模型能够关注最重要的信息。


1. 核心公式

缩放点积注意力的计算过程如下:

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)VAttention(Q,K,V)=softmax(dkQKT)V

其中:

  • ( Q )(Query):查询向量,表示当前要计算的 token。
  • ( K )(Key):键向量,用于计算与 Query 的相似度。
  • ( V )(Value):值向量,存储实际的信息。
  • ( dkd_kdk ):Key 的维度(用于缩放,防止 softmax 梯度消失)。

2. 计算步骤

(1)计算 Query 和 Key 的点积

S=QKTS = QK^TS=QKT

  • 计算 Query 和 Key 的相似度(点积),得到 注意力分数矩阵 ( S )
  • 点积值越大,表示两个 token 越相关。

(2)缩放(Scaling)

Sscaled=SdkS_{\text{scaled}} = \frac{S}{\sqrt{d_k}}Sscaled=dkS

  • 由于点积值可能过大,导致 softmax 梯度消失(梯度趋近于 0),因此除以 ( \sqrt{d_k} ) 进行缩放。

(3)Softmax 归一化

A=softmax(Sscaled)A = \text{softmax}(S_{\text{scaled}})A=softmax(Sscaled)

  • 对缩放后的分数进行 softmax,得到 注意力权重矩阵 ( A )(概率分布,总和为 1)。

(4)加权求和 Value

Output=A⋅V\text{Output} = A \cdot VOutput=AV

  • 用注意力权重对 Value 进行加权求和,得到最终的输出。

直观例子
假设输入句子为 “The cat sat on the mat”:

  1. Query(Q):当前处理的词(如 “sat”)。
  2. Key(K):句子中所有词的表示(“The”, “cat”, …, “mat”)。
  3. Value(V):与Key对应的实际信息(如词向量)。
  4. 输出:通过权重计算,“sat” 可能对 “cat” 和 “mat” 分配高权重,生成包含上下文信息的表示。

3. 为什么需要缩放?

  • 防止梯度消失:如果 ( d_k ) 很大(如 512 或 1024),点积值可能非常大,导致 softmax 进入饱和区(接近 0 或 1),梯度变得极小,影响训练。
  • 稳定训练:缩放后,softmax 输入保持在合理范围,使梯度更稳定。

4. 代码实现(PyTorch)

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: [batch_size, num_heads, seq_len, d_k]
    K: [batch_size, num_heads, seq_len, d_k]
    V: [batch_size, num_heads, seq_len, d_v]
    mask: [batch_size, seq_len, seq_len] (可选)
    """
    d_k = Q.size(-1)  # Key 的维度
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)  # [batch, heads, seq_len, seq_len]
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # 掩码处理(如 Transformer 的解码器)
    
    attn_weights = F.softmax(scores, dim=-1)  # 计算注意力权重
    output = torch.matmul(attn_weights, V)  # 加权求和
    
    return output, attn_weights

5. 应用场景

  1. Transformer 的自注意力(Self-Attention)
    • 用于计算输入序列内部的关系(如 BERT、GPT)。
  2. 交叉注意力(Cross-Attention)
    • 用于不同模态之间的交互(如文本-图像生成,Stable Diffusion 中的 U-Net 使用)。
  3. 多模态模型(如 CLIP、Flamingo)
    • 计算文本和图像之间的相关性。

6. 与普通点积注意力的区别

对比项普通点积注意力缩放点积注意力
计算方式(QKTQK^TQKT )( QKTdk\frac{QK^T}{\sqrt{d_k}}dkQKT)
梯度问题容易梯度消失稳定训练
适用场景小维度 Key大维度 Key(如 Transformer)

Transformer

7. 在Transformer中的应用

Scaled Dot-Product Attention(缩放点积注意力)是Transformer架构的核心组件:

  • ​基础计算单元​:Transformer中的自注意力(Self-Attention)和编码器-解码器注意力(Encoder-Decoder Attention)均基于Scaled Dot-Product Attention实现
  • 多头注意力(Multi-Head Attention)​​:
    Transformer通过并行多个Scaled Dot-Product Attention(即多头机制),从不同子空间捕捉信息,增强模型表达能力
    。每个头的计算独立,最终拼接并线性变换
  • ​编码器与解码器的关键模块​:
    • 编码器​:每层包含一个多头自注意力(Self-Attention),用于捕捉输入序列内部的依赖关系。
    • ​解码器​:除自注意力外,还增加编码器-解码器注意力层,使解码器能动态关注编码器输出的关键信息。

8. 总结

  • 作用:计算输入序列的权重分布,使模型能动态关注重要部分。
  • 关键改进:缩放因子 ( \sqrt{d_k} ) 防止 softmax 饱和。
  • 应用:几乎所有现代 Transformer 模型(如 GPT、BERT、Stable Diffusion)的核心组件。

如果你在实现 Transformer 或阅读相关论文时遇到它,现在应该能清晰理解它的原理了!

Transformer

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值