在Transformer架构的工程优化中,注意力机制的计算效率是核心瓶颈之一。标准的缩放点积注意力(Scaled Dot-Product Attention)存在 O(T²d) 的时间复杂度和内存占用问题——当序列长度T超过1k时,显存消耗会急剧增加,甚至导致训练中断。为解决这一问题,FlashAttention-v2通过分块计算和LogSumExp数值优化,在保持精度的前提下,将显存占用降低至O(Td),同时通过硬件感知优化提升计算速度。
本文基于Stanford CS336作业2要求,详细拆解FlashAttention-v2的两种实现方案:纯PyTorch分块版本(理解核心逻辑)和Triton内核加速版本(工业级性能),并对比分析其设计思路与性能优势。
一、FlashAttention-v2核心原理回顾
在深入代码前,需先明确FlashAttention-v2解决的核心痛点与关键优化手段:
1.1 标准注意力的痛点
标准注意力计算流程为:
- 计算注意力分数矩阵 ( S = QK^T / \sqrt{d_k} )(形状:( B \times T_q \times T_k ))
- 应用掩码(如因果掩码)后计算Softmax:( P = \text{Softmax}(S) )
- 加权求和得到输出:( O = PV )
问题在于:当 ( T_q = T_k = 2048 ) 时,( S ) 和 ( P ) 的形状为 ( B \times 2048 \times 2048 ),单个float32矩阵就需占用 ( 2048 \times 2048 \times 4 \approx 16MB ),若 batch_size=32,则仅注意力矩阵就需占用 ( 32 \times 16MB = 512MB )——而实际场景中序列长度常达4k、8k,显存消耗会呈平方级增长。
1.2 FlashAttention-v2的核心优化
FlashAttention-v2通过分块计算(Tile-based Computation)和LogSumExp数值稳定技巧,将“一次性计算全量矩阵”改为“逐块计算并累积结果”,核心思路如下:
- 分块策略:将 ( Q )(( T_q \times d_k ))按行分成多个Query块(( B_q \times d_k )),将 ( K )(( T_k \times d_k ))和 ( V )(( T_k \times d_v ))按列分成多个Key-Value块(( B_k \times d_k ) 和 ( B_k \times d_v ))。
- 逐块累积:对每个Query块,循环遍历所有Key-Value块,计算局部注意力分数并累积到输出 ( O ) 中,全程不存储完整的 ( S ) 和 ( P ) 矩阵。
- LogSumExp优化:为避免分块Softmax的精度损失,使用LogSumExp公式累积概率权重,保证全局Softmax结果与标准计算一致。
二、纯PyTorch实现:FlashAttenTorch
首先实现纯PyTorch版本的FlashAttention(FlashAttenTorch),该版本不依赖任何底层加速框架,仅通过分块逻辑展示FlashAttention的核心流程,便于理解原理。
2.1 类结构与前向传播
FlashAttenTorch 继承自 torch.autograd.Function,需自定义 forward(前向计算)和 backward(反向梯度)方法。
2.1.1 前向传播(Forward)
前向传播的核心是“分块遍历Query和Key-Value,累积输出 ( O ) 和LogSumExp中间结果 ( L )”,步骤如下:
class FlashAttenTorch(torch.autograd.Function):
@staticmethod
def forward(ctx, Q, K, V, is_causal=False, Q_TILE_SIZE=16, K_TILE_SIZE=16):
"""
输入:
Q: [B, Tq, dk] → Query矩阵
K: [B, Tk, dk] → Key矩阵
V: [B, Tk, dv] → Value矩阵
is_causal: 是否启用因果掩码(防止关注未来token)
Q_TILE_SIZE: Query分块大小(Bq)
K_TILE_SIZE: Key-Value分块大小(Bk)
输出:
O: [B, Tq, dv] → 注意力输出
"""
B, Tq, dk = Q.shape
Tk = K.size(1)
dv = V.size(2)
scale = 1.0 / (dk ** 0.5) # 注意力缩放因子
# 初始化输出O和LogSumExp中间结果L
O = torch.zeros(B, Tq, dv, device=Q.device, dtype=Q.dtype)
L = torch.zeros(B, Tq, device=Q.device, dtype=Q.dtype)
# 1. 遍历所有Query块(按Q_TILE_SIZE分块)
for q_start in range(0, Tq, Q_TILE_SIZE):
q_end = min(q_start + Q_TILE_SIZE, Tq)
Qi = Q[:, q_start:q_end, :] # 当前Query块:[B, Bq, dk]
current_q_size = q_end - q_start
# 初始化当前Query块的最大值(用于LogSumExp)
pre_mx = torch.full((B, current_q_size), float('-inf'), device=Q.device, dtype=Q.dtype)
# 因果掩码需用到的Query位置索引
if is_causal:
q_pos = torch.arange(q_start, q_end, device=Q.device) # [Bq]
# 2. 遍历所有Key-Value块(按K_TILE_SIZE分块)
for k_start in range(0, Tk, K_TILE_SIZE):
k_end = min(k_start + K_TILE_SIZE, Tk)
Kj = K[:, k_start:k_end, :] # 当前Key块:[B, Bk, dk]
Vj = V[:, k_start:k_end, :] # 当前Value块:[B, Bk, dv]
# 3. 计算局部注意力分数 Sij = Qi @ Kj^T / sqrt(dk)
Sij = einsum(Qi, Kj, "... Bq dk, ... Bk dk -> ... Bq Bk") * scale # [B, Bq, Bk]
# 4. 应用因果掩码(仅当前Query块能关注之前的Key块)
if is_causal:
k_pos = torch.arange(k_start, k_end, device=Q.device) # [Bk]
mask = q_pos[:, None] >= k_pos[None, :] # [Bq, Bk]:True表示可关注
Sij = torch.where(mask, Sij, torch.tensor(float('-inf'), device=Sij.device))
# 5. LogSumExp累积:更新最大值和权重和
current_mx = torch.max(Sij, dim=-1).values # [B, Bq]:当前Key块的Sij最大值
mx = torch.max(pre_mx, current_mx) # [B, Bq]:累积最大值
# 计算局部概率权重(指数归一化)
Pij = torch.exp(Sij - mx.unsqueeze(-1)) # [B, Bq, Bk]
# 累积LogSumExp的权重和 L(对应全局Softmax的分母)
L[:, q_start:q_end] = torch.exp(pre_mx - mx) * L[:, q_start:q_end] + torch.sum(Pij, dim=-1)
# 累积输出 O(对应全局 PV 的部分和)
O[:, q_start:q_end, :] = (torch.exp(pre_mx - mx).unsqueeze(-1) * O[:, q_start:q_end, :]
+ einsum(Pij, Vj, "... Bq Bk, ... Bk dv -> ... Bq dv"))
# 更新前一轮最大值,准备下一个Key块
pre_mx = mx
# 6. 归一化当前Query块的输出(全局Softmax的最终结果)
O[:, q_start:q_end, :] /= L[:, q_start:q_end].unsqueeze(-1)
# 更新L为全局LogSumExp结果(用于反向传播)
L[:, q_start:q_end] = mx + torch.log(L[:, q_start:q_end])
# 保存反向传播所需的中间变量
ctx.save_for_backward(Q, K, V, O, L)
ctx.is_causal = is_causal
return O
2.1.2 反向传播(Backward)
反向传播需计算梯度 ( dQ, dK, dV ),核心是基于前向保存的 ( O, L ) 推导局部梯度并累积。这里采用PyTorch编译加速(torch.compile)提升反向计算效率:
@staticmethod
def backward(ctx, grad_out):
"""
输入:
grad_out: [B, Tq, dv] → 输出O的梯度
输出:
dQ: [B, Tq, dk] → Q的梯度
dK: [B, Tk, dk] → K的梯度
dV: [B, Tk, dv] → V的梯度
"""
Q, K, V, O, L = ctx.saved_tensors
is_causal = ctx.is_causal
# 调用预编译的反向计算函数
dQ, dK, dV, _ = compiled_flash_bwd(Q, K, V, O, L, grad_out, is_causal)
return dQ, dK, dV, None # 后两个None对应is_causal和TileSize的梯度(无需计算)
# 预编译反向计算函数,提升效率
def flash_bwd(Q, K, V, O, L, dO, is_causal=False):
B, Tq, dk = Q.shape
Tk = K.size(1)
scale = 1.0 / (dk ** 0.5)
# 1. 计算中间变量 D = O · dO^T(用于梯度链式法则)
D = torch.sum(O * dO, dim=-1, keepdim=True) # [B, Tq, 1]
# 2. 重构注意力分数 S(基于前向保存的L)
S = torch.matmul(Q, K.transpose(-1, -2)) * scale # [B, Tq, Tk]
if is_causal:
mask = torch.triu(torch.ones(Tq, Tk, device=Q.device, dtype=torch.bool), diagonal=1)
S = S

最低0.47元/天 解锁文章
927

被折叠的 条评论
为什么被折叠?



