Stanford CS336 | Assignment 2 - FlashAttention-v2 Pytorch & Triotn实现

在Transformer架构的工程优化中,注意力机制的计算效率是核心瓶颈之一。标准的缩放点积注意力(Scaled Dot-Product Attention)存在 O(T²d) 的时间复杂度和内存占用问题——当序列长度T超过1k时,显存消耗会急剧增加,甚至导致训练中断。为解决这一问题,FlashAttention-v2通过分块计算LogSumExp数值优化,在保持精度的前提下,将显存占用降低至O(Td),同时通过硬件感知优化提升计算速度。

本文基于Stanford CS336作业2要求,详细拆解FlashAttention-v2的两种实现方案:纯PyTorch分块版本(理解核心逻辑)和Triton内核加速版本(工业级性能),并对比分析其设计思路与性能优势。

一、FlashAttention-v2核心原理回顾

在深入代码前,需先明确FlashAttention-v2解决的核心痛点与关键优化手段:

1.1 标准注意力的痛点

标准注意力计算流程为:

  1. 计算注意力分数矩阵 ( S = QK^T / \sqrt{d_k} )(形状:( B \times T_q \times T_k ))
  2. 应用掩码(如因果掩码)后计算Softmax:( P = \text{Softmax}(S) )
  3. 加权求和得到输出:( O = PV )

问题在于:当 ( T_q = T_k = 2048 ) 时,( S ) 和 ( P ) 的形状为 ( B \times 2048 \times 2048 ),单个float32矩阵就需占用 ( 2048 \times 2048 \times 4 \approx 16MB ),若 batch_size=32,则仅注意力矩阵就需占用 ( 32 \times 16MB = 512MB )——而实际场景中序列长度常达4k、8k,显存消耗会呈平方级增长。

1.2 FlashAttention-v2的核心优化

FlashAttention-v2通过分块计算(Tile-based Computation)和LogSumExp数值稳定技巧,将“一次性计算全量矩阵”改为“逐块计算并累积结果”,核心思路如下:

  1. 分块策略:将 ( Q )(( T_q \times d_k ))按行分成多个Query块(( B_q \times d_k )),将 ( K )(( T_k \times d_k ))和 ( V )(( T_k \times d_v ))按列分成多个Key-Value块(( B_k \times d_k ) 和 ( B_k \times d_v ))。
  2. 逐块累积:对每个Query块,循环遍历所有Key-Value块,计算局部注意力分数并累积到输出 ( O ) 中,全程不存储完整的 ( S ) 和 ( P ) 矩阵。
  3. LogSumExp优化:为避免分块Softmax的精度损失,使用LogSumExp公式累积概率权重,保证全局Softmax结果与标准计算一致。

二、纯PyTorch实现:FlashAttenTorch

首先实现纯PyTorch版本的FlashAttention(FlashAttenTorch),该版本不依赖任何底层加速框架,仅通过分块逻辑展示FlashAttention的核心流程,便于理解原理。

2.1 类结构与前向传播

FlashAttenTorch 继承自 torch.autograd.Function,需自定义 forward(前向计算)和 backward(反向梯度)方法。

2.1.1 前向传播(Forward)

前向传播的核心是“分块遍历Query和Key-Value,累积输出 ( O ) 和LogSumExp中间结果 ( L )”,步骤如下:

class FlashAttenTorch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, is_causal=False, Q_TILE_SIZE=16, K_TILE_SIZE=16):
        """
        输入:
            Q: [B, Tq, dk] → Query矩阵
            K: [B, Tk, dk] → Key矩阵
            V: [B, Tk, dv] → Value矩阵
            is_causal: 是否启用因果掩码(防止关注未来token)
            Q_TILE_SIZE: Query分块大小(Bq)
            K_TILE_SIZE: Key-Value分块大小(Bk)
        输出:
            O: [B, Tq, dv] → 注意力输出
        """
        B, Tq, dk = Q.shape
        Tk = K.size(1)
        dv = V.size(2)
        scale = 1.0 / (dk ** 0.5)  # 注意力缩放因子

        # 初始化输出O和LogSumExp中间结果L
        O = torch.zeros(B, Tq, dv, device=Q.device, dtype=Q.dtype)
        L = torch.zeros(B, Tq, device=Q.device, dtype=Q.dtype)

        # 1. 遍历所有Query块(按Q_TILE_SIZE分块)
        for q_start in range(0, Tq, Q_TILE_SIZE):
            q_end = min(q_start + Q_TILE_SIZE, Tq)
            Qi = Q[:, q_start:q_end, :]  # 当前Query块:[B, Bq, dk]
            current_q_size = q_end - q_start

            # 初始化当前Query块的最大值(用于LogSumExp)
            pre_mx = torch.full((B, current_q_size), float('-inf'), device=Q.device, dtype=Q.dtype)

            # 因果掩码需用到的Query位置索引
            if is_causal:
                q_pos = torch.arange(q_start, q_end, device=Q.device)  # [Bq]

            # 2. 遍历所有Key-Value块(按K_TILE_SIZE分块)
            for k_start in range(0, Tk, K_TILE_SIZE):
                k_end = min(k_start + K_TILE_SIZE, Tk)
                Kj = K[:, k_start:k_end, :]  # 当前Key块:[B, Bk, dk]
                Vj = V[:, k_start:k_end, :]  # 当前Value块:[B, Bk, dv]

                # 3. 计算局部注意力分数 Sij = Qi @ Kj^T / sqrt(dk)
                Sij = einsum(Qi, Kj, "... Bq dk, ... Bk dk -> ... Bq Bk") * scale  # [B, Bq, Bk]

                # 4. 应用因果掩码(仅当前Query块能关注之前的Key块)
                if is_causal:
                    k_pos = torch.arange(k_start, k_end, device=Q.device)  # [Bk]
                    mask = q_pos[:, None] >= k_pos[None, :]  # [Bq, Bk]:True表示可关注
                    Sij = torch.where(mask, Sij, torch.tensor(float('-inf'), device=Sij.device))

                # 5. LogSumExp累积:更新最大值和权重和
                current_mx = torch.max(Sij, dim=-1).values  # [B, Bq]:当前Key块的Sij最大值
                mx = torch.max(pre_mx, current_mx)  # [B, Bq]:累积最大值

                # 计算局部概率权重(指数归一化)
                Pij = torch.exp(Sij - mx.unsqueeze(-1))  # [B, Bq, Bk]

                # 累积LogSumExp的权重和 L(对应全局Softmax的分母)
                L[:, q_start:q_end] = torch.exp(pre_mx - mx) * L[:, q_start:q_end] + torch.sum(Pij, dim=-1)

                # 累积输出 O(对应全局 PV 的部分和)
                O[:, q_start:q_end, :] = (torch.exp(pre_mx - mx).unsqueeze(-1) * O[:, q_start:q_end, :] 
                                        + einsum(Pij, Vj, "... Bq Bk, ... Bk dv -> ... Bq dv"))

                # 更新前一轮最大值,准备下一个Key块
                pre_mx = mx

            # 6. 归一化当前Query块的输出(全局Softmax的最终结果)
            O[:, q_start:q_end, :] /= L[:, q_start:q_end].unsqueeze(-1)
            # 更新L为全局LogSumExp结果(用于反向传播)
            L[:, q_start:q_end] = mx + torch.log(L[:, q_start:q_end])

        # 保存反向传播所需的中间变量
        ctx.save_for_backward(Q, K, V, O, L)
        ctx.is_causal = is_causal
        return O
2.1.2 反向传播(Backward)

反向传播需计算梯度 ( dQ, dK, dV ),核心是基于前向保存的 ( O, L ) 推导局部梯度并累积。这里采用PyTorch编译加速(torch.compile)提升反向计算效率:

    @staticmethod
    def backward(ctx, grad_out):
        """
        输入:
            grad_out: [B, Tq, dv] → 输出O的梯度
        输出:
            dQ: [B, Tq, dk] → Q的梯度
            dK: [B, Tk, dk] → K的梯度
            dV: [B, Tk, dv] → V的梯度
        """
        Q, K, V, O, L = ctx.saved_tensors
        is_causal = ctx.is_causal

        # 调用预编译的反向计算函数
        dQ, dK, dV, _ = compiled_flash_bwd(Q, K, V, O, L, grad_out, is_causal)
        return dQ, dK, dV, None  # 后两个None对应is_causal和TileSize的梯度(无需计算)


# 预编译反向计算函数,提升效率
def flash_bwd(Q, K, V, O, L, dO, is_causal=False):
    B, Tq, dk = Q.shape
    Tk = K.size(1)
    scale = 1.0 / (dk ** 0.5)

    # 1. 计算中间变量 D = O · dO^T(用于梯度链式法则)
    D = torch.sum(O * dO, dim=-1, keepdim=True)  # [B, Tq, 1]

    # 2. 重构注意力分数 S(基于前向保存的L)
    S = torch.matmul(Q, K.transpose(-1, -2)) * scale  # [B, Tq, Tk]
    if is_causal:
        mask = torch.triu(torch.ones(Tq, Tk, device=Q.device, dtype=torch.bool), diagonal=1)
        S = S
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值