XAttention 计算步骤详解及示例
XAttention 是一种高效的块稀疏注意力机制,通过 反对角线评分(Antidiagonal Scoring) 和 动态阈值选择 来优化长序列 Transformer 模型的推理效率。以下是其核心计算步骤及具体示例。
1. XAttention 的核心步骤
Step 1: 计算原始注意力分数
输入:
- Query Q∈Rn×dQ \in \mathbb{R}^{n \times d}Q∈Rn×d
- Key K∈Rm×dK \in \mathbb{R}^{m \times d}K∈Rm×d
- Value V∈Rm×dvV \in \mathbb{R}^{m \times d_v}V∈Rm×dv
计算未缩放的注意力分数:
S=QKTS = QK^TS=QKT
Step 2: 反对角线评分(Antidiagonal Scoring)
- 分块计算:将 SSS 划分为 B×BB \times BB×B 的块(如 8×88 \times 88×8)。
- 反对角线求和:对每个块,计算反对角线(从左下到右上)元素的和,作为块的重要性分数:
Score=∑i+j=kAi,j\text{Score} = \sum_{i+j=k} A_{i,j}Score=∑i+j=kAi,j- 其中 kkk 是反对角线索引,例如 k=0,1,...,2B−2k=0,1,...,2B-2k=0,1,...,2B−2。
Step 3: 阈值块选择
- 归一化:对块分数进行 softmax 归一化:
P=softmax(Score)P = \text{softmax}(\text{Score})P=softmax(Score) - 选择关键块:保留累积概率超过阈值 τ\tauτ 的最小块集合 B∗B^*B∗:
B∗=argmin∣B∣s.t.∑(i,j)∈BPi,j>τB^* = \arg \min |B| \quad \text{s.t.} \quad \sum_{(i,j) \in B} P_{i,j} > \tauB∗=argmin∣B∣s.t.∑(i,j)∈BPi,j>τ
Step 4: 稀疏注意力计算
仅计算选中的关键块 B∗B^*B∗ 的注意力权重,并加权聚合 VVV:
Output=∑(i,j)∈B∗Ai,jVj\text{Output} = \sum_{(i,j) \in B^*} A_{i,j} V_jOutput=∑(i,j)∈B∗Ai,jVj
2. 计算示例
输入数据
假设 d=2d=2d=2,输入如下:
- Query (Q):
Q=[1.02.03.04.0]Q = \begin{bmatrix} 1.0 & 2.0 \\ 3.0 & 4.0 \\ \end{bmatrix}Q=[1.03.02.04.0] - Key (K):
K=[5.06.07.08.09.010.0]K = \begin{bmatrix} 5.0 & 6.0 \\ 7.0 & 8.0 \\ 9.0 & 10.0 \\ \end{bmatrix}K=5.07.09.06.08.010.0 - Value (V):
V=[1.00.01.00.01.00.01.01.00.0]V = \begin{bmatrix} 1.0 & 0.0 & 1.0 \\ 0.0 & 1.0 & 0.0 \\ 1.0 & 1.0 & 0.0 \\ \end{bmatrix}V=1.00.01.00.01.01.01.00.00.0
Step 1: 计算原始注意力分数 S=QKTS = QK^TS=QKT
S=[1⋅5+2⋅61⋅7+2⋅81⋅9+2⋅103⋅5+4⋅63⋅7+4⋅83⋅9+4⋅10]=[172329395367]S = \begin{bmatrix} 1 \cdot 5 + 2 \cdot 6 & 1 \cdot 7 + 2 \cdot 8 & 1 \cdot 9 + 2 \cdot 10 \\ 3 \cdot 5 + 4 \cdot 6 & 3 \cdot 7 + 4 \cdot 8 & 3 \cdot 9 + 4 \cdot 10 \\ \end{bmatrix} = \begin{bmatrix} 17 & 23 & 29 \\ 39 & 53 & 67 \\ \end{bmatrix}S=[1⋅5+2⋅63⋅5+4⋅61⋅7+2⋅83⋅7+4⋅81⋅9+2⋅103⋅9+4⋅10]=[173923532967]
Step 2: 反对角线评分(假设块大小 2×22 \times 22×2)
- 块 1(S1:2,1:2S_{1:2,1:2}S1:2,1:2):
反对角线元素={17,53}和=17+53=70\text{反对角线元素} = \{17, 53\} \quad \text{和} = 17 + 53 = 70反对角线元素={17,53}和=17+53=70 - 块 2(S1:2,2:3S_{1:2,2:3}S1:2,2:3):
反对角线元素={23,67}和=23+67=90\text{反对角线元素} = \{23, 67\} \quad \text{和} = 23 + 67 = 90反对角线元素={23,67}和=23+67=90
Step 3: 阈值块选择(假设 τ=0.6\tau = 0.6τ=0.6)
- 归一化:
P=softmax([70,90])≈[0.27,0.73]P = \text{softmax}([70, 90]) \approx [0.27, 0.73]P=softmax([70,90])≈[0.27,0.73] - 选择关键块:
- 累积概率:0.27+0.73=1.0>τ0.27 + 0.73 = 1.0 > \tau0.27+0.73=1.0>τ,因此选择 块 2(S1:2,2:3S_{1:2,2:3}S1:2,2:3)。
Step 4: 稀疏注意力计算
仅计算块 2 的注意力权重:
Aselected=softmax([23,67]2)≈[0.0001,0.9999]A_{\text{selected}} = \text{softmax}\left(\frac{[23, 67]}{\sqrt{2}}\right) \approx [0.0001, 0.9999]Aselected=softmax(2[23,67])≈[0.0001,0.9999]
加权聚合 VVV:
Output=0.0001⋅[0.0,1.0,0.0]+0.9999⋅[1.0,1.0,0.0]≈[1.0,1.0,0.0]\text{Output} = 0.0001 \cdot [0.0, 1.0, 0.0] + 0.9999 \cdot [1.0, 1.0, 0.0] \approx [1.0, 1.0, 0.0]Output=0.0001⋅[0.0,1.0,0.0]+0.9999⋅[1.0,1.0,0.0]≈[1.0,1.0,0.0]
3. 关键优势
- 计算高效:反对角线评分的计算复杂度低,仅需 O(B2)O(B^2)O(B2) 而非 O(N2)O(N^2)O(N2)。
- 模式保留:反对角线能捕捉垂直/斜线依赖(如视频帧间的时空关联)。
- 动态适应性:通过阈值调整可平衡计算量与精度。
XAttention 在长文本和视频任务中可加速 13.5 倍,同时保持全注意力的精度。