XAttention 计算步骤详解及示例

XAttention 计算步骤详解及示例

XAttention 是一种高效的块稀疏注意力机制,通过 反对角线评分(Antidiagonal Scoring)动态阈值选择 来优化长序列 Transformer 模型的推理效率。以下是其核心计算步骤及具体示例。


1. XAttention 的核心步骤

Step 1: 计算原始注意力分数

输入:

  • Query Q∈Rn×dQ \in \mathbb{R}^{n \times d}QRn×d
  • Key K∈Rm×dK \in \mathbb{R}^{m \times d}KRm×d
  • Value V∈Rm×dvV \in \mathbb{R}^{m \times d_v}VRm×dv

计算未缩放的注意力分数:
S=QKTS = QK^TS=QKT

Step 2: 反对角线评分(Antidiagonal Scoring)

  1. 分块计算:将 SSS 划分为 B×BB \times BB×B 的块(如 8×88 \times 88×8)。
  2. 反对角线求和:对每个块,计算反对角线(从左下到右上)元素的和,作为块的重要性分数:
    Score=∑i+j=kAi,j\text{Score} = \sum_{i+j=k} A_{i,j}Score=i+j=kAi,j
    • 其中 kkk 是反对角线索引,例如 k=0,1,...,2B−2k=0,1,...,2B-2k=0,1,...,2B2

Step 3: 阈值块选择

  1. 归一化:对块分数进行 softmax 归一化:
    P=softmax(Score)P = \text{softmax}(\text{Score})P=softmax(Score)
  2. 选择关键块:保留累积概率超过阈值 τ\tauτ 的最小块集合 B∗B^*B
    B∗=arg⁡min⁡∣B∣s.t.∑(i,j)∈BPi,j>τB^* = \arg \min |B| \quad \text{s.t.} \quad \sum_{(i,j) \in B} P_{i,j} > \tauB=argminBs.t.(i,j)BPi,j>τ

Step 4: 稀疏注意力计算

仅计算选中的关键块 B∗B^*B 的注意力权重,并加权聚合 VVV
Output=∑(i,j)∈B∗Ai,jVj\text{Output} = \sum_{(i,j) \in B^*} A_{i,j} V_jOutput=(i,j)BAi,jVj


2. 计算示例

输入数据

假设 d=2d=2d=2,输入如下:

  • Query (Q)
    Q=[1.02.03.04.0]Q = \begin{bmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ \end{bmatrix}Q=[1.03.02.04.0]
  • Key (K)
    K=[5.06.07.08.09.010.0]K = \begin{bmatrix} 5.0 & 6.0 \\ 7.0 & 8.0 \\ 9.0 & 10.0 \\ \end{bmatrix}K=5.07.09.06.08.010.0
  • Value (V)
    V=[1.00.01.00.01.00.01.01.00.0]V = \begin{bmatrix} 1.0 & 0.0 & 1.0 \\ 0.0 & 1.0 & 0.0 \\ 1.0 & 1.0 & 0.0 \\ \end{bmatrix}V=1.00.01.00.01.01.01.00.00.0

Step 1: 计算原始注意力分数 S=QKTS = QK^TS=QKT

S=[1⋅5+2⋅61⋅7+2⋅81⋅9+2⋅103⋅5+4⋅63⋅7+4⋅83⋅9+4⋅10]=[172329395367]S = \begin{bmatrix} 1 \cdot 5 + 2 \cdot 6 & 1 \cdot 7 + 2 \cdot 8 & 1 \cdot 9 + 2 \cdot 10 \\ 3 \cdot 5 + 4 \cdot 6 & 3 \cdot 7 + 4 \cdot 8 & 3 \cdot 9 + 4 \cdot 10 \\ \end{bmatrix} = \begin{bmatrix} 17 & 23 & 29 \\ 39 & 53 & 67 \\ \end{bmatrix}S=[15+2635+4617+2837+4819+21039+410]=[173923532967]

Step 2: 反对角线评分(假设块大小 2×22 \times 22×2

  • 块 1S1:2,1:2S_{1:2,1:2}S1:2,1:2):
    反对角线元素={17,53}和=17+53=70\text{反对角线元素} = \{17, 53\} \quad \text{和} = 17 + 53 = 70反对角线元素={17,53}=17+53=70
  • 块 2S1:2,2:3S_{1:2,2:3}S1:2,2:3):
    反对角线元素={23,67}和=23+67=90\text{反对角线元素} = \{23, 67\} \quad \text{和} = 23 + 67 = 90反对角线元素={23,67}=23+67=90

Step 3: 阈值块选择(假设 τ=0.6\tau = 0.6τ=0.6

  1. 归一化
    P=softmax([70,90])≈[0.27,0.73]P = \text{softmax}([70, 90]) \approx [0.27, 0.73]P=softmax([70,90])[0.27,0.73]
  2. 选择关键块
    • 累积概率:0.27+0.73=1.0>τ0.27 + 0.73 = 1.0 > \tau0.27+0.73=1.0>τ,因此选择 块 2S1:2,2:3S_{1:2,2:3}S1:2,2:3)。

Step 4: 稀疏注意力计算

仅计算块 2 的注意力权重:
Aselected=softmax([23,67]2)≈[0.0001,0.9999]A_{\text{selected}} = \text{softmax}\left(\frac{[23, 67]}{\sqrt{2}}\right) \approx [0.0001, 0.9999]Aselected=softmax(2[23,67])[0.0001,0.9999]
加权聚合 VVV
Output=0.0001⋅[0.0,1.0,0.0]+0.9999⋅[1.0,1.0,0.0]≈[1.0,1.0,0.0]\text{Output} = 0.0001 \cdot [0.0, 1.0, 0.0] + 0.9999 \cdot [1.0, 1.0, 0.0] \approx [1.0, 1.0, 0.0]Output=0.0001[0.0,1.0,0.0]+0.9999[1.0,1.0,0.0][1.0,1.0,0.0]


3. 关键优势

  1. 计算高效:反对角线评分的计算复杂度低,仅需 O(B2)O(B^2)O(B2) 而非 O(N2)O(N^2)O(N2)
  2. 模式保留:反对角线能捕捉垂直/斜线依赖(如视频帧间的时空关联)。
  3. 动态适应性:通过阈值调整可平衡计算量与精度。

XAttention 在长文本和视频任务中可加速 13.5 倍,同时保持全注意力的精度。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值