Decoding Attention: LLM Inference Optimization

Contents

1 Background

2 Results

2.1 Test Conditions

2.2 Device Specifications

2.3 RTX3090

3 Decoding Attention

3.1 HGEMV(S = Q * K^T)

3.2 Softmax(P = Softmax(S))

3.3 HGEMV(O = P * V)

4 Miscellaneous

4.1 Plans


1 Background

Thanks to flash-attention's optimization of MHA (Multi-Head Attention) computation, both LLM training and inference performance have improved enormously. Building on Flash Attention v2, Tri Dao wrote Flash Decoding specifically to optimize inference: it mainly changes the block partitioning scheme, using multiple blocks to process the same attention head so that the KV Cache is loaded in parallel, and finally launches a separate kernel to rescale and merge the partial results. As a result, Flash Decoding delivers a large inference speedup for small batches with very long sequences.

However, when Flash Attention is applied to LLM inference, the following problems remain:

  • Low Tensor Core utilization in the decoding stage: on the Ampere architecture (NVIDIA A100, RTX3090, etc.), Tensor Core utilization is only 1/16, since with seq_q = 1 only one of the 16 rows of an m16n8k16 MMA tile is occupied

  • Only the FP16 and BF16 data types are supported; quantization schemes such as FP8, Int8, and Int4 are not, which makes it harder to reduce inference cost

  • Only vanilla attention is supported, not variants such as ALiBi

Because seq_q is always 1 in the decoding stage of LLM inference, the per-head MHA computation reduces to two HGEMVs and one Softmax. The flow is sketched below.
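
To make the data flow concrete, here is a minimal scalar C++ reference for a single head (FP32 throughout for clarity; the shapes, the 1/sqrt(head_dim) softmax scale, and the function name are illustrative assumptions, not the actual kernel interface):

    // Scalar reference for one head at decode time (seq_q == 1):
    // S = Q * K^T (HGEMV), P = Softmax(S), O = P * V (HGEMV).
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    std::vector<float> decode_mha_ref(const std::vector<float> &Q,  // [head_dim]
                                      const std::vector<float> &K,  // [seq_k, head_dim]
                                      const std::vector<float> &V,  // [seq_k, head_dim]
                                      std::size_t seq_k, std::size_t head_dim) {
        const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));  // assumed scale

        // HGEMV: S = Q * K^T, one dot product per cached key row.
        std::vector<float> S(seq_k);
        for (std::size_t i = 0; i < seq_k; ++i) {
            float acc = 0.0f;
            for (std::size_t d = 0; d < head_dim; ++d) {
                acc += Q[d] * K[i * head_dim + d];
            }
            S[i] = acc * scale;
        }

        // P = Softmax(S), max-subtracted for numerical stability.
        const float s_max = *std::max_element(S.begin(), S.end());
        float exp_sum = 0.0f;
        for (std::size_t i = 0; i < seq_k; ++i) {
            S[i] = std::exp(S[i] - s_max);
            exp_sum += S[i];
        }

        // HGEMV: O = P * V, a weighted sum of the cached value rows.
        std::vector<float> O(head_dim, 0.0f);
        for (std::size_t i = 0; i < seq_k; ++i) {
            const float p = S[i] / exp_sum;
            for (std::size_t d = 0; d < head_dim; ++d) {
                O[d] += p * V[i * head_dim + d];
            }
        }
        return O;
    }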

Note that on the RTX3090, Tensor Core FP16 throughput with FP32 accumulation (71 TFLOPS) is only 2× the CUDA Core FP32 throughput (35.6 TFLOPS). Under these conditions, a hand-written CUDA Core kernel for decoding MHA can deliver gains in both latency and hardware utilization while preserving accuracy. Moreover, a hand-written decoding MHA kernel not only makes it straightforward to further adopt KV Cache quantization to raise LLM inference throughput and thereby lower inference cost, it can also support more attention variants (such as ALiBi).
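
A back-of-envelope estimate with the numbers from the specification table below makes the point (taking the 1/16 utilization figure above at face value and ignoring memory-bandwidth limits):

    \[
    \underbrace{71\ \mathrm{TFLOPS}}_{\text{FP16 Tensor, FP32 acc.}} \times \frac{1}{16}
    \approx 4.4\ \mathrm{TFLOPS}
    \;\ll\;
    \underbrace{35.6\ \mathrm{TFLOPS}}_{\text{FP32 CUDA Core}}
    \]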

2 Results

This work draws mainly on ppl.llm.kernel.cuda and flash-attention, using CUDA Cores to optimize MHA performance in the decoding stage of LLM inference. In some decoding scenarios, Decoding Attention currently outperforms Flash Decoding (Flash Attention) and FlashInfer. Decoding Attention supports the GQA (Group Query Attention) / MQA (Multi Query Attention) and ALiBi (Attention with Linear Biases) inference scenarios, as well as the FP16 and BF16 data types. Decoding Attention provides both a C++ API and a Python API; the code is open-sourced as decoding_attention.

2.1 Test Conditions

  • MHA: O = Softmax(Q * K^T) * V

  • CUDA: 12.1

  • GPU: RTX3090

  • Flash Attention: v2.6.3

  • FlashInfer: v0.1.6

  • Head Num: 32

  • Head Dim: 128

  • Data Type: FP16

2.2 Device Specifications

The RTX3090 device specifications are as follows.

Graphics Card                                 RTX3090
GPU Codename                                  GA102
GPU Architecture                              Ampere
GPCs                                          7
TPCs                                          41
SMs                                           82
CUDA Cores / SM                               128
CUDA Cores / GPU                              10496
Tensor Cores / SM                             4 (3rd Gen)
Tensor Cores / GPU                            328 (3rd Gen)
GPU Boost Clock (MHz)                         1695
Peak FP32 TFLOPS                              35.6
Peak FP16 TFLOPS                              35.6
Peak FP16 Tensor TFLOPS (FP16 Accumulate)     142
Peak FP16 Tensor TFLOPS (FP32 Accumulate)     71
Memory Interface                              384-bit
Memory Clock (Data Rate)                      19.5 Gbps
Memory Bandwidth                              936 GB/sec
L1 Data Cache / Shared Memory                 10496 KB
L2 Cache Size                                 6144 KB
Register File Size                            20992 KB

2.3 RTX3090

(1) Seq Len

For sequence lengths below 1536, Decoding Attention outperforms Flash Decoding (Flash Attention) and FlashInfer; above 1536, Flash Decoding (Flash Attention) and FlashInfer perform better.

  • Batch Size: 1

  • Seq Q: 1

  • Seq K: Seq Len

(2) Batch Size

Decoding Attention outperforms Flash Decoding (Flash Attention) and FlashInfer at every batch size.

  • Batch Size: Batch Size

  • Seq Q: 1

  • Seq K: 128

3 Decoding Attention

As noted above, seq_q is always 1 in the decoding stage of LLM inference, so the per-head MHA computation reduces to two HGEMVs and one Softmax. Decoding Attention assigns blocks by batch and head, and the kernel consists of three parts: HGEMV (S = Q * K^T), Softmax (P = Softmax(S)), and HGEMV (O = P * V). The source code is in decoding_attention.
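
A sketch of a plausible launch geometry under this partitioning (the kernel name, the Params fields, and the thread count are assumptions for illustration; the real configuration lives in the repo):

    // One thread block per (batch, head) pair; each block runs the full
    // three-stage pipeline (S = Q*K^T, P = Softmax(S), O = P*V) for one head.
    constexpr int threads_per_block = 256;            // assumed
    dim3 grid(params.b, params.h);                    // batch x head
    dim3 block(threads_per_block);

    // Dynamic shared memory: seq_k float scores (S_smem), reused at the end
    // to stage the head_dim output row, hence the max of the two sizes.
    size_t smem_bytes = std::max(max_seq_k, head_dim) * sizeof(float);
    decoding_attention_kernel<half, /*IsAlibi=*/false>
        <<<grid, block, smem_bytes>>>(params);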

3.1 HGEMV(S = Q * K^T)

Each group processes one or more seqlen_k positions, and the computation follows the usual HGEMV flow; see the open-source cuda_hgemv for reference.

    // S = Q * K^T
    T RQ[thread_elem_nums];

#pragma unroll
    for (size_t i = 0; i < thread_iters; ++i) {
        *(int4 *)(&RQ[i * thread_copy_elem_nums]) =
            *(int4 *)(&q_ptr[binfo.q_offset(params.q_row_stride, params.q_head_stride,
                                            (i * threads_per_group + group_lane_id) * thread_copy_elem_nums)]);
    }

    // Shared scores for every seq_k position; later reused in section 3.3
    // to stage the head_dim output row.
    extern __shared__ float S_smem[];
    float S_max = -std::numeric_limits<float>::max();

#pragma unroll
    for (size_t base_seq_k = warp_id * groups_per_warp; base_seq_k < binfo.actual_seq_k;
         base_seq_k += groups_per_block) {
        size_t seq_k = base_seq_k + group_id;
        T RK[thread_elem_nums];

        float acc = 0.0;
        if (seq_k < binfo.actual_seq_k) {
#pragma unroll
            for (size_t i = 0; i < thread_iters; ++i) {
                *(int4 *)(&RK[i * thread_copy_elem_nums]) =
                    *(int4 *)(&k_ptr[binfo.k_offset(seq_k, params.k_row_stride, params.k_head_stride,
                                                    (i * threads_per_group + group_lane_id) * thread_copy_elem_nums)]);
            }

#pragma unroll
            for (size_t i = 0; i < thread_elem_nums; ++i) {
                if constexpr (std::is_same_v<T, half>) {
                    acc += (__half2float(RQ[i]) * __half2float(RK[i]));
                } else {
                    acc += (__bfloat162float(RQ[i]) * __bfloat162float(RK[i]));
                }
            }
        }

#pragma unroll
        for (size_t i = threads_per_group / 2; i >= 1; i /= 2) {
            acc += __shfl_xor_sync(shfl_mask, acc, i);
        }

        if (group_lane_id == 0 && seq_k < binfo.actual_seq_k) {
            acc *= params.scale_softmax;

            if (IsAlibi) {
                acc += (binfo.h_slope * (static_cast<int>(seq_k) - binfo.actual_seq_q - binfo.row_shift));
            }

            S_smem[seq_k] = acc;
            S_max = fmaxf(acc, S_max);
        }
    }
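
The int4 loads above move 16 bytes (8 half elements) per thread per copy. With the benchmark's head_dim = 128, one consistent sizing works out as follows (these concrete constants are an assumption for illustration; the repo derives them from its template parameters):

    #include <cstddef>
    #include <cuda_fp16.h>

    // 16-byte (int4) vectorized loads: 8 half elements per copy.
    constexpr size_t thread_copy_elem_nums = 16 / sizeof(half);   // 8
    // Example: 16 threads cooperate on one seq_k row of head_dim = 128.
    constexpr size_t head_dim          = 128;
    constexpr size_t threads_per_group = 16;
    constexpr size_t thread_iters =
        head_dim / (threads_per_group * thread_copy_elem_nums);   // 1
    constexpr size_t thread_elem_nums =
        thread_iters * thread_copy_elem_nums;                     // 8

Under this sizing, each 32-lane warp contains 2 groups, i.e. it scores 2 seq_k positions per outer-loop iteration.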

3.2 Softmax(P = Softmax(S))

First, the per-group maxima of S computed in the previous step are reduced to obtain the row maximum of S; the Softmax for each seqlen_k position is then computed in place in S_smem.

    // P = Softmax(S)
    __shared__ float softmax_smem[warps_per_block];

#pragma unroll
    for (size_t i = warp_size / 2; i >= 1; i /= 2) {
        S_max = fmaxf(S_max, __shfl_xor_sync(shfl_mask, S_max, i));
    }

    if (lane_id == 0) {
        softmax_smem[warp_id] = S_max;
    }

    __syncthreads();

    if (lane_id < warps_per_block) {
        S_max = softmax_smem[lane_id];
    } else {
        S_max = -std::numeric_limits<float>::max();
    }

#pragma unroll
    for (size_t i = warps_per_block / 2; i >= 1; i /= 2) {
        S_max = fmaxf(S_max, __shfl_xor_sync(shfl_mask, S_max, i));
    }

    S_max = __shfl_sync(shfl_mask, S_max, 0);

    float exp_sum = 0.0;
#pragma unroll
    for (size_t seq_k = threadIdx.x; seq_k < binfo.actual_seq_k; seq_k += threads_per_block) {
        S_smem[seq_k] -= S_max;
        S_smem[seq_k] = exp(S_smem[seq_k]);
        exp_sum += S_smem[seq_k];
    }

#pragma unroll
    for (size_t i = warp_size / 2; i >= 1; i /= 2) {
        exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);
    }

    if (lane_id == 0) {
        softmax_smem[warp_id] = exp_sum;
    }

    __syncthreads();

    if (lane_id < warps_per_block) {
        exp_sum = softmax_smem[lane_id];
    }

#pragma unroll
    for (size_t i = warps_per_block / 2; i >= 1; i /= 2) {
        exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);
    }
    exp_sum = __shfl_sync(shfl_mask, exp_sum, 0);

#pragma unroll
    for (size_t seq_k = threadIdx.x; seq_k < binfo.actual_seq_k; seq_k += threads_per_block) {
        S_smem[seq_k] /= exp_sum;
    }

    __syncthreads();
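
For reference, the quantity computed by the two reductions above is the max-subtracted softmax; shifting by the row maximum cancels in the ratio but keeps every exponent non-positive, so exp cannot overflow:

    \[
    P_i \;=\; \frac{e^{S_i - m}}{\sum_{k} e^{S_k - m}},
    \qquad m = \max_j S_j
    \]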

3.3 HGEMV(O = P * V)

Because of how the V matrix is stored, each group here scales one or more rows of V by the corresponding softmax weight, and a reduce-sum across groups produces the final result.

    // O = P * V
    T RV[thread_elem_nums];
    float RO[thread_elem_nums];

    memset(RO, 0, sizeof(RO));

#pragma unroll
    for (size_t base_seq_k = warp_id * groups_per_warp; base_seq_k < binfo.actual_seq_k;
         base_seq_k += groups_per_block) {
        size_t seq_k = base_seq_k + group_id;

        if (seq_k < binfo.actual_seq_k) {
#pragma unroll
            for (size_t i = 0; i < thread_iters; ++i) {
                *(int4 *)(&RV[i * thread_copy_elem_nums]) =
                    *(int4 *)(&v_ptr[binfo.k_offset(seq_k, params.v_row_stride, params.v_head_stride,
                                                    (i * threads_per_group + group_lane_id) * thread_copy_elem_nums)]);
            }

#pragma unroll
            for (size_t i = 0; i < thread_elem_nums; ++i) {
                if constexpr (std::is_same_v<T, half>) {
                    RO[i] += (S_smem[seq_k] * __half2float(RV[i]));
                } else {
                    RO[i] += (S_smem[seq_k] * __bfloat162float(RV[i]));
                }
            }
        }
    }

#pragma unroll
    for (size_t i = 0; i < thread_elem_nums; ++i) {
#pragma unroll
        for (size_t j = threads_per_group; j <= warp_size / 2; j *= 2) {
            RO[i] += __shfl_xor_sync(shfl_mask, RO[i], j);
        }
    }

    __syncthreads();

#pragma unroll
    for (size_t i = threadIdx.x; i < head_dim; i += threads_per_block) {
        S_smem[i] = 0.0;
    }

    __syncthreads();

    if (lane_id < threads_per_group) {
#pragma unroll
        for (size_t i = 0; i < thread_iters; ++i) {
#pragma unroll
            for (size_t j = 0; j < thread_copy_elem_nums; ++j) {
                atomicAdd(S_smem + (i * threads_per_group + lane_id) * thread_copy_elem_nums + j,
                          RO[i * thread_copy_elem_nums + j]);
            }
        }
    }

    __syncthreads();

#pragma unroll
    for (size_t i = threadIdx.x; i < head_dim; i += threads_per_block) {
        if constexpr (std::is_same_v<T, half>) {
            o_ptr[binfo.q_offset(params.o_row_stride, params.o_head_stride, i)] = __float2half(S_smem[i]);
        } else {
            o_ptr[binfo.q_offset(params.o_row_stride, params.o_head_stride, i)] = __float2bfloat16(S_smem[i]);
        }
    }
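
The first reduction after the main loop is a butterfly across the groups sharing a warp: starting at stride threads_per_group and doubling up to warp_size / 2, __shfl_xor_sync folds the partial RO vectors of all groups together. A standalone toy kernel showing the same pattern (the group size of 4 is purely illustrative):

    #include <cstdio>

    // 32 lanes split into 8 groups of 4. Strides 4, 8, 16 sum the partial
    // values across groups, so every lane l ends with the sum over all
    // lanes congruent to l mod 4.
    __global__ void cross_group_sum() {
        float v = static_cast<float>(threadIdx.x);  // per-lane partial
        for (int j = 4; j <= 16; j *= 2) {
            v += __shfl_xor_sync(0xffffffffu, v, j);
        }
        if (threadIdx.x < 4) {
            printf("lane %u: %.0f\n", threadIdx.x, v);  // 112, 120, 128, 136
        }
    }

    int main() {
        cross_group_sum<<<1, 32>>>();
        cudaDeviceSynchronize();
        return 0;
    }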

4 Miscellaneous

4.1 Plans

  • Kernel optimization
  • KV Cache quantization: FP8, Int8, Int4
