作者:叶子豪
原文:https://zhuanlan.zhihu.com/p/25920092499
在过去的两个月中,FlashInfer团队在密集地迭代优化DeepSeek MLA性能,并支持了DeepSeek MLA在prefill/decode/extend各个阶段的attention算子(支持SM_80/SM_89/SM_90,和任意的page大小),并在v0.2.2版[1]本中集中发布。在这篇文章中,我们将介绍近期我们一些版本中所做的一些优化,以及背后的思路。抛砖引玉,希望能启发到大家。
Prefill阶段的MLA计算与普通的MHA计算大致相同,唯一的区别在于需要支持q/k和v/o使用不同的head_dim,只需要简单修改attention模版即可(见 feat: support deepseek prefill attention shape by yzh119 · Pull Request #765 · flashinfer-ai/flashinfer[2])。而Decode和Extend阶段,我们需要单独设计高效的融合Page Attention算子,以便高效地与低秩kv-cache进行attention计算,相对于其它attention变体而言,MLA的挑战主要来自于以下两点:
-
• 矩阵吸收之后较大的head dimension (512)
-
• 矩阵吸收之后key和value共用同一个矩阵,流水线需要重新设计
我们将介绍flashinfer团队在这些问题上的解决方案(我们相信还有很多更好的设计),希望能引起大家的讨论和思考。
背景——MLA
对于MLA模型本身,已经有了很多的解读,推荐阅读苏神的文章了解背景: 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA [3]
背景——矩阵吸收
在推理阶段的注意力计算过程中,为了避免对 升维带来的显存带宽开销,我们可以改变矩阵的结合顺序(以下的推导省略了多头):,,以减少注意力算子的输入矩阵大小(I/O决定了解码阶段的效率),可以在注意力计算之前算好(在解码的过程中相对较小),这样整个Attention的计算转换为(同样,这里忽略了“头”的维度,其中和 均为多头矩阵)):
这个技巧叫做矩阵吸收,详细推导可以参考章老师的文章:DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子[4]。通过矩阵吸收变换,我们可以将Multi-Head-Attention(MHA)计算转换成Multi-Query-Attention(MQA)计算,key和value共享一个 c_{kv} 矩阵,也就是MLA中的KV-Cache,其中没有“头”这个维度,query的部分有多头。
经过矩阵吸收变换之后,MLA的计算流程如下:
MLA计算流程(https://github.com/flashinfer-ai/flashinfer/pull/551),方框中是FlashInfer需要实现的MLA算子,其中c_kv被同时用做K和V,head_dim为512,rope部分的head_dim为64
可以用代码描述为:
q_pe = W_QR(c_q)
q_nope = W_UQ_UK(c_q)
output = W_UV_O(MQA(q_pe, q_nope, c_kv, k_pe))
因此我们的目标是设计一个MQA算子,qk计算部分需要加上rope的维度(这块只需要简单地修改算子即可),重用key-cache和value-cache矩阵,同时需要支持较大的head_dim(512)。
是否需要融合算子?
在实现这个算子之前,我们先思考一个问题,我们是否需要这样一个融合算子:FlashAttention最初设计的初衷是减少对softmax矩阵储存以及方寸的开销,其大小正比于 ,占整体I/O的比值为:
而对于推理阶段而言 l_q 其实是非常小的,如果大家手写过推理阶段的算子,其实会发现不融合qk和pv两阶段的计算也能取得不错的效果(目前sglang的triton后端有一个这样的实现,可以参考flashinfer中的这个rfc:
https://github.com/flashinfer-ai/flashinfer/issues/707
但是对于MLA而言,融合是必要的:
-
• MLA有较大的group ratio: ,会增大softmax的占比
-
• MLA复用了key和value矩阵,因此如果我们不融合两阶段的话,前后两个算子将各自访问一遍KV-Cache,如果硬件的cache不够大的话,带宽利用率将无法超过50%。
对sglang的MLA Triton的实现做profiling我们也能得到相应的结论(后文中我们将展示实验):融合是必要的。
MQA/GQA优化——查询与头维度的融合
对于MQA和GQA的解码阶段,一种常用的优化技巧是把共用一个KV头的所有QO头,与query的行数融合(因为他们需要跟相同的KV-Cache做Attention计算):
Head Group融合,来自FlashInfer论文的Appendix A: https://arxiv.org/pdf/2501.01005
这样的效果是增加了有效的行数,增加了算子密度,自回归解码阶段虽然说查询的长度是1,但是经过Head Group融合之后,有效行数增大到 。
对于DeepSeek-V3中的MLA,因为 ,如果不做矩阵并行(TP),那么即便查询长度只有1,算子的密度也能达到256(注意K和V复用,因此访存量小一半),如果做推测解码,则很容易达到计算密集。在FlashInfer中,我们使用Head Group融合来增加MQA/GQA算子的算子密度。
挑战1——更大的Head Dimension
写过FlashAttention的同学都知道,我们通常希望把输出矩阵O放在寄存器中,这样可以省去频繁从其它内存层级中读写的开销(在blackwell上我们则没有这样的顾虑,但是对于先前的架构,tensor cores的输出都在寄存器中)。
如果我们让一个warp group(128个线程)负责64行,那么在head_dim=512的情况下,输出一共需要 个fp32寄存器,平分到每个线程则是 个寄存器,超出了NVIDIA GPU的限制。
H100上的wgmma指令最小大小是64行,因此我们希望不要在行这一维度切割,相反,我们在head_dim维度切割,每256分到一个warp group( ),这样每个线程需要128个寄存器来储存O的中间结果,刚好合适。
这样的分块方式可以完美并行FlashAttention的第二阶段 () ,但是两个分块都需要相同的 P 矩阵(由第一阶段产生),此处我们有几种选择:
-
• 增加一个warp group(这样一共有三个consumer warp group)单独计算 。
-
• 仍然使用这两个warp group,但是共同计算 ,注意到第二阶段需要完整的 P 矩阵,因此我们需要借助shared memory进行一次all-gather操作:
flashinfer在此处选择了后一种方式,下图是具体的计算流程(在 https://github.com/flashinfer-ai/flashinfer/pull/804
中实现)。
Head Dimension划分逻辑
其每个SM的shared memory占用为
针对不同硬件的shared memory大小,我们需要分别选择不同的KV分块大小和流水线深度。在flashinfer中,我们对于SM80(A100),SM89(4090,Ada6000, L40,...)和SM90(H100, H200,...)进行了支持,并相应地选择了参数。
挑战2——Hopper软件流水
FlashAttention3中提出了使用Hopper的异步wgmma指令来重叠cuda cores和tensor cores的操作,尽管在MLA中,由于巨大的head_dim,cuda cores计算的占比要小得多(tensor cores计算的flops与head_dim成正比,而cuda cores上的计算与head_dim无关),但我们仍然观察到了接近10%的开销。因此我们希望用FA3中的(3.2)软件流水的方法来并行tensor cores和cuda cores计算。
对于MLA而言,我们发现如果把流水线深度开成2,那么发射两个wgmma之后就没有剩余的shared memory空间可以用来同时读取下一个分块的KV-Cache了,因此我们把KV分块大小减半(64->32),并增加流水线深度(2->4),这样发射了两个wgmma指令之后并不会阻塞生产者的工作,可以同时读取接下来的KV-Cache。
为什么FlashAttention没有这个问题呢,因为FlashAttention的K和V是分开的,天然分阶段,而MLA中K和V重用。FlashAttention中的2阶段实际上等价于4=2*2(K/V)阶段。
在FlashAttention3中,我们把V的shared memory空间给O重用(因为把O从寄存器写到global memory的过程中,需要使用shared memory转换格式),而在MLA中,我们遇到一样的问题,因为K和V是同一个矩阵,如果把V的空间给O,那么会阻塞下一组work中的KV读取。
在flashinfer中,我们的解决方案是对O的shared memory也建立循环缓冲区,跟KV-Cache的阶段数一样,因为O需要的shared memory空间是每个KV-Cache分块的两倍,因此每次我们需要占用两个槽,为此我们需要建立一个mbarrier来保证内存序。整体流水线设计如图所示:
Hopper上的流水线设计,最上方一行代表生产者(一个warp group),下方代表消费者(一共两个warp group,按照head dimension切分的方式合作计算),虚线代表mbarrier。相同的颜色代表使用相同的shared memory空间。
在这样的设计下,我们一共占用了221KB的shared memory(Hopper系列架构每个SM有228KB的shared memory)。
负载均衡
我们使用了flashinfer论文中的负载均衡调度器(执行顺序确定,无随机性)来应对同一个batch中不同request kv-cache长度的不均衡性:在每个生成步骤之前,我们使用最小堆把work分配到当前累计代价最小的CTA中。每次调度生成的调度信息可以被所有层重用,因此代价可以被均摊掉。
较长的kv可能会被分成多块,对于这些块,我们的kernel第一阶段把每个块的中间结果写进workspace buffer中,第二阶段将这些块的中间结果合并。在flashinfer的MLA的实现中(在 perf: dynamic split-k for MLA by yzh119 · Pull Request #863 · flashinfer-ai/flashinfer 中实现[5] ),第一个阶段和第二阶段写进了同一个kernel中,使用grid barrier分开。我们还使用了论文中(D.2章节)的直写优化,对于没有切割kv的work,调过workspace空间,直接写进最终的buffer,只对做了分块的kv写进指定的workspace空间中,这样的设计可以保证workspace buffer的大小不会太大(由硬件SM数量和head_dimension这些值决定),同时省掉了很多读写workspace的开销。
动态split-k调度器,来自flashinfer论文(3.3节):https://arxiv.org/pdf/2501.01005
性能
我们在H100 SXM5(80GB),A100 SXM4(40GB)上进行了Page Attention解码算子层面的benchmark,并比较了sglang中的triton实现,结果如下:
H100上的MLA Page Attention Decoding Kernel性能,纵轴为达到的显存带宽,横轴为不同的实验设置,Page Size选择为1。
对于H100,我们选取了sglang的Triton MLA实现作为baseline,并对比了flashinfer中的两种实现(FA2和FA3),FA2使用两个warp group,流水线深度2,KV分块大小64,FA3使用三个warp group,流水线深度4,KV分块大小为32。FA3能达到接近3TB/s的显存带宽(H100的显存带宽上限为3352GB/s)。
A100上的MLA Page Attention Decoding Kernel性能,纵轴为达到的显存带宽,横轴为不同的实验设置,Page Size选择为1。
对于A100,我们选取了sglang的Triton MLA实现作为baseline。由于A100不支持Hopper的特性,我们只比较了flashinfer中的FA2实现:两个warp group,流水线深度2,KV分块大小32。由于A100不支持硬件原生的stmatri指令,allgather操作的开销较大,在最优情况下,能接近1TB/s的显存带宽(上线为1555GB/s)。
sglang对接
sglang团队在最近几周中积极对接了flashinfer的MLA实现(因为sm_scale的原因踩了不少的坑,导致中间一度怀疑算子实现问题,实际上是sm_scale的错误),并在各类end-to-end任务中展现出了显著的加速:
-
• Refactor flashinfer logic for deepseek v3 and fix accuracy bug by Fridge003 · Pull Request #3785 · sgl-project/sglang[6]
-
• DeepSeek MTP with flashinfer backend by jsun-012 · Pull Request #3775 · sgl-project/sglang[7]
-
• feat: support flashinfer mla attention for deepseek v3 by zhyncs · Pull Request #3550 · sgl-project/sglang[8]
我们将保持与开源社区的合作,并把性能推到更好。
后记
flashinfer的成长离不开开源社区的反馈,和大家的贡献。社区成员tsu-bin - Overview[9] 贡献了第一个版本的原型,并且详细地解释了MLA算子的计算流程(feat: support MLA decode by tsu-bin · Pull Request #551 · flashinfer-ai/flashinfer[10]),如果不是这个PR,我可能到现在都还不知道MLA的需求。我们也同样感谢这几周和我们一起debug调性能的sglang团队( @Yineng Zhang 等同学),和给我们提出关键建议的vllm团队。
在我们发布v0.2.2版本几个小时以后,DeepSeek发布了官方的FlashMLA实现[11],目前官方的版本支持page_size=64(flashinfer目前支持任意page size)。对于更大的page_size,我们可以使用tma减少生产者阶段的寄存器占用和指令发射压力,以及multi-casting来支持更高效的跨SM的KV-Cache访问(对num_local_heads=128的时候有用),我们将在接下来的几个版本支持。
我们保持开发路线公开透明([Roadmap] FlashInfer v0.2 to v0.3 · Issue #675 · flashinfer-ai/flashinfer[12]), 希望能和社区一起高效地迭代,共建LLM时代的基础架构。
引用链接
[1]
v0.2.2版:https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.2.2
[2]
eat: support deepseek prefill attention shape by yzh119 · Pull Request #765 · flashinfer-ai/flashinfer:https://github.com/flashinfer-ai/flashinfer/pull/765
[3]
缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA :https://spaces.ac.cn/archives/10091
[4]
DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子:https://zhuanlan.zhihu.com/p/700214123
[5]
perf: dynamic split-k for MLA by yzh119 · Pull Request #863 · flashinfer-ai/flashinfer 中实现:https://github.com/flashinfer-ai/flashinfer/pull/863
[6]
Refactor flashinfer logic for deepseek v3 and fix accuracy bug by Fridge003 · Pull Request #3785 · sgl-project/sglang:https://github.com/sgl-project/sglang/pull/3785
[7]
DeepSeek MTP with flashinfer backend by jsun-012 · Pull Request #3775 · sgl-project/sglang:https://github.com/sgl-project/sglang/pull/3775
[8]
feat: support flashinfer mla attention for deepseek v3 by zhyncs · Pull Request #3550 · sgl-project/sglang:https://github.com/sgl-project/sglang/pull/3550
[9]
tsu-bin - Overview:https://github.com/tsu-bin
[10]
feat: support MLA decode by tsu-bin · Pull Request #551 · flashinfer-ai/flashinfer:https://github.com/flashinfer-ai/flashinfer/pull/551
[11]
FlashMLA实现:https://github.com/deepseek-ai/FlashMLA/
[12]
[Roadmap] FlashInfer v0.2 to v0.3 · Issue #675 · flashinfer-ai/flashinfer:https://github.com/flashinfer-ai/flashinfer/issues/675