FlashInfer中DeepSeek MLA的内核设计

作者:叶子豪
原文:https://zhuanlan.zhihu.com/p/25920092499
 


在过去的两个月中,FlashInfer团队在密集地迭代优化DeepSeek MLA性能,并支持了DeepSeek MLA在prefill/decode/extend各个阶段的attention算子(支持SM_80/SM_89/SM_90,和任意的page大小),并在v0.2.2版[1]本中集中发布。在这篇文章中,我们将介绍近期我们一些版本中所做的一些优化,以及背后的思路。抛砖引玉,希望能启发到大家。

Prefill阶段的MLA计算与普通的MHA计算大致相同,唯一的区别在于需要支持q/k和v/o使用不同的head_dim,只需要简单修改attention模版即可(见 feat: support deepseek prefill attention shape by yzh119 · Pull Request #765 · flashinfer-ai/flashinfer[2])。而Decode和Extend阶段,我们需要单独设计高效的融合Page Attention算子,以便高效地与低秩kv-cache进行attention计算,相对于其它attention变体而言,MLA的挑战主要来自于以下两点:

  • • 矩阵吸收之后较大的head dimension (512)

  • • 矩阵吸收之后key和value共用同一个矩阵,流水线需要重新设计

我们将介绍flashinfer团队在这些问题上的解决方案(我们相信还有很多更好的设计),希望能引起大家的讨论和思考。

背景——MLA

对于MLA模型本身,已经有了很多的解读,推荐阅读苏神的文章了解背景: 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA [3]

背景——矩阵吸收

在推理阶段的注意力计算过程中,为了避免对 升维带来的显存带宽开销,我们可以改变矩阵的结合顺序(以下的推导省略了多头):,,以减少注意力算子的输入矩阵大小(I/O决定了解码阶段的效率),可以在注意力计算之前算好(在解码的过程中相对较小),这样整个Attention的计算转换为(同样,这里忽略了“头”的维度,其中和 均为多头矩阵)):

这个技巧叫做矩阵吸收,详细推导可以参考章老师的文章:DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子[4]。通过矩阵吸收变换,我们可以将Multi-Head-Attention(MHA)计算转换成Multi-Query-Attention(MQA)计算,key和value共享一个 c_{kv} 矩阵,也就是MLA中的KV-Cache,其中没有“头”这个维度,query的部分有多头。

经过矩阵吸收变换之后,MLA的计算流程如下:

MLA计算流程(https://github.com/flashinfer-ai/flashinfer/pull/551),方框中是FlashInfer需要实现的MLA算子,其中c_kv被同时用做K和V,head_dim为512,rope部分的head_dim为64

MLA计算流程(https://github.com/flashinfer-ai/flashinfer/pull/551),方框中是FlashInfer需要实现的MLA算子,其中c_kv被同时用做K和V,head_dim为512,rope部分的head_dim为64

可以用代码描述为:

q_pe = W_QR(c_q)
q_nope = W_UQ_UK(c_q)
output = W_UV_O(MQA(q_pe, q_nope, c_kv, k_pe))

因此我们的目标是设计一个MQA算子,qk计算部分需要加上rope的维度(这块只需要简单地修改算子即可),重用key-cache和value-cache矩阵,同时需要支持较大的head_dim(512)。

是否需要融合算子?

在实现这个算子之前,我们先思考一个问题,我们是否需要这样一个融合算子:FlashAttention最初设计的初衷是减少对softmax矩阵储存以及方寸的开销,其大小正比于  ,占整体I/O的比值为:

而对于推理阶段而言 l_q 其实是非常小的,如果大家手写过推理阶段的算子,其实会发现不融合qk和pv两阶段的计算也能取得不错的效果(目前sglang的triton后端有一个这样的实现,可以参考flashinfer中的这个rfc:

https://github.com/flashinfer-ai/flashinfer/issues/707

但是对于MLA而言,融合是必要的:

  • • MLA有较大的group ratio:  ,会增大softmax的占比

  • • MLA复用了key和value矩阵,因此如果我们不融合两阶段的话,前后两个算子将各自访问一遍KV-Cache,如果硬件的cache不够大的话,带宽利用率将无法超过50%。

对sglang的MLA Triton的实现做profiling我们也能得到相应的结论(后文中我们将展示实验):融合是必要的。

MQA/GQA优化——查询与头维度的融合

对于MQA和GQA的解码阶段,一种常用的优化技巧是把共用一个KV头的所有QO头,与query的行数融合(因为他们需要跟相同的KV-Cache做Attention计算):

Head Group融合,来自FlashInfer论文的Appendix A: https://arxiv.org/pdf/2501.01005

Head Group融合,来自FlashInfer论文的Appendix A: https://arxiv.org/pdf/2501.01005

这样的效果是增加了有效的行数,增加了算子密度,自回归解码阶段虽然说查询的长度是1,但是经过Head Group融合之后,有效行数增大到  。

对于DeepSeek-V3中的MLA,因为 ,如果不做矩阵并行(TP),那么即便查询长度只有1,算子的密度也能达到256(注意K和V复用,因此访存量小一半),如果做推测解码,则很容易达到计算密集。在FlashInfer中,我们使用Head Group融合来增加MQA/GQA算子的算子密度。

挑战1——更大的Head Dimension

写过FlashAttention的同学都知道,我们通常希望把输出矩阵O放在寄存器中,这样可以省去频繁从其它内存层级中读写的开销(在blackwell上我们则没有这样的顾虑,但是对于先前的架构,tensor cores的输出都在寄存器中)。

如果我们让一个warp group(128个线程)负责64行,那么在head_dim=512的情况下,输出一共需要  个fp32寄存器,平分到每个线程则是  个寄存器,超出了NVIDIA GPU的限制。

H100上的wgmma指令最小大小是64行,因此我们希望不要在行这一维度切割,相反,我们在head_dim维度切割,每256分到一个warp group(  ),这样每个线程需要128个寄存器来储存O的中间结果,刚好合适。

这样的分块方式可以完美并行FlashAttention的第二阶段 () ,但是两个分块都需要相同的 P 矩阵(由第一阶段产生),此处我们有几种选择:

  • • 增加一个warp group(这样一共有三个consumer warp group)单独计算  。

  • • 仍然使用这两个warp group,但是共同计算  ,注意到第二阶段需要完整的 P 矩阵,因此我们需要借助shared memory进行一次all-gather操作:

flashinfer在此处选择了后一种方式,下图是具体的计算流程(在 https://github.com/flashinfer-ai/flashinfer/pull/804 中实现)。

Head Dimension划分逻辑

Head Dimension划分逻辑

其每个SM的shared memory占用为

针对不同硬件的shared memory大小,我们需要分别选择不同的KV分块大小和流水线深度。在flashinfer中,我们对于SM80(A100),SM89(4090,Ada6000, L40,...)和SM90(H100, H200,...)进行了支持,并相应地选择了参数。

挑战2——Hopper软件流水

FlashAttention3中提出了使用Hopper的异步wgmma指令来重叠cuda cores和tensor cores的操作,尽管在MLA中,由于巨大的head_dim,cuda cores计算的占比要小得多(tensor cores计算的flops与head_dim成正比,而cuda cores上的计算与head_dim无关),但我们仍然观察到了接近10%的开销。因此我们希望用FA3中的(3.2)软件流水的方法来并行tensor cores和cuda cores计算。

对于MLA而言,我们发现如果把流水线深度开成2,那么发射两个wgmma之后就没有剩余的shared memory空间可以用来同时读取下一个分块的KV-Cache了,因此我们把KV分块大小减半(64->32),并增加流水线深度(2->4),这样发射了两个wgmma指令之后并不会阻塞生产者的工作,可以同时读取接下来的KV-Cache。

为什么FlashAttention没有这个问题呢,因为FlashAttention的K和V是分开的,天然分阶段,而MLA中K和V重用。FlashAttention中的2阶段实际上等价于4=2*2(K/V)阶段。

在FlashAttention3中,我们把V的shared memory空间给O重用(因为把O从寄存器写到global memory的过程中,需要使用shared memory转换格式),而在MLA中,我们遇到一样的问题,因为K和V是同一个矩阵,如果把V的空间给O,那么会阻塞下一组work中的KV读取。

在flashinfer中,我们的解决方案是对O的shared memory也建立循环缓冲区,跟KV-Cache的阶段数一样,因为O需要的shared memory空间是每个KV-Cache分块的两倍,因此每次我们需要占用两个槽,为此我们需要建立一个mbarrier来保证内存序。整体流水线设计如图所示:

Hopper上的流水线设计,最上方一行代表生产者(一个warp group),下方代表消费者(一共两个warp group,按照head dimension切分的方式合作计算),虚线代表mbarrier。相同的颜色代表使用相同的shared memory空间。

Hopper上的流水线设计,最上方一行代表生产者(一个warp group),下方代表消费者(一共两个warp group,按照head dimension切分的方式合作计算),虚线代表mbarrier。相同的颜色代表使用相同的shared memory空间。

在这样的设计下,我们一共占用了221KB的shared memory(Hopper系列架构每个SM有228KB的shared memory)。

负载均衡

我们使用了flashinfer论文中的负载均衡调度器(执行顺序确定,无随机性)来应对同一个batch中不同request kv-cache长度的不均衡性:在每个生成步骤之前,我们使用最小堆把work分配到当前累计代价最小的CTA中。每次调度生成的调度信息可以被所有层重用,因此代价可以被均摊掉。

较长的kv可能会被分成多块,对于这些块,我们的kernel第一阶段把每个块的中间结果写进workspace buffer中,第二阶段将这些块的中间结果合并。在flashinfer的MLA的实现中(在 perf: dynamic split-k for MLA by yzh119 · Pull Request #863 · flashinfer-ai/flashinfer 中实现[5] ),第一个阶段和第二阶段写进了同一个kernel中,使用grid barrier分开。我们还使用了论文中(D.2章节)的直写优化,对于没有切割kv的work,调过workspace空间,直接写进最终的buffer,只对做了分块的kv写进指定的workspace空间中,这样的设计可以保证workspace buffer的大小不会太大(由硬件SM数量和head_dimension这些值决定),同时省掉了很多读写workspace的开销。

动态split-k调度器,来自flashinfer论文(3.3节):https://arxiv.org/pdf/2501.01005

动态split-k调度器,来自flashinfer论文(3.3节):https://arxiv.org/pdf/2501.01005

性能

我们在H100 SXM5(80GB),A100 SXM4(40GB)上进行了Page Attention解码算子层面的benchmark,并比较了sglang中的triton实现,结果如下:

H100上的MLA Page Attention Decoding Kernel性能,纵轴为达到的显存带宽,横轴为不同的实验设置,Page Size选择为1。

H100上的MLA Page Attention Decoding Kernel性能,纵轴为达到的显存带宽,横轴为不同的实验设置,Page Size选择为1。

对于H100,我们选取了sglang的Triton MLA实现作为baseline,并对比了flashinfer中的两种实现(FA2和FA3),FA2使用两个warp group,流水线深度2,KV分块大小64,FA3使用三个warp group,流水线深度4,KV分块大小为32。FA3能达到接近3TB/s的显存带宽(H100的显存带宽上限为3352GB/s)。

A100上的MLA Page Attention Decoding Kernel性能,纵轴为达到的显存带宽,横轴为不同的实验设置,Page Size选择为1。

A100上的MLA Page Attention Decoding Kernel性能,纵轴为达到的显存带宽,横轴为不同的实验设置,Page Size选择为1。

对于A100,我们选取了sglang的Triton MLA实现作为baseline。由于A100不支持Hopper的特性,我们只比较了flashinfer中的FA2实现:两个warp group,流水线深度2,KV分块大小32。由于A100不支持硬件原生的stmatri指令,allgather操作的开销较大,在最优情况下,能接近1TB/s的显存带宽(上线为1555GB/s)。

sglang对接

sglang团队在最近几周中积极对接了flashinfer的MLA实现(因为sm_scale的原因踩了不少的坑,导致中间一度怀疑算子实现问题,实际上是sm_scale的错误),并在各类end-to-end任务中展现出了显著的加速:

  • • Refactor flashinfer logic for deepseek v3 and fix accuracy bug by Fridge003 · Pull Request #3785 · sgl-project/sglang[6]

  • • DeepSeek MTP with flashinfer backend by jsun-012 · Pull Request #3775 · sgl-project/sglang[7]

  • • feat: support flashinfer mla attention for deepseek v3 by zhyncs · Pull Request #3550 · sgl-project/sglang[8]

我们将保持与开源社区的合作,并把性能推到更好。

后记

flashinfer的成长离不开开源社区的反馈,和大家的贡献。社区成员tsu-bin - Overview[9] 贡献了第一个版本的原型,并且详细地解释了MLA算子的计算流程(feat: support MLA decode by tsu-bin · Pull Request #551 · flashinfer-ai/flashinfer[10]),如果不是这个PR,我可能到现在都还不知道MLA的需求。我们也同样感谢这几周和我们一起debug调性能的sglang团队( @Yineng Zhang 等同学),和给我们提出关键建议的vllm团队。

在我们发布v0.2.2版本几个小时以后,DeepSeek发布了官方的FlashMLA实现[11],目前官方的版本支持page_size=64(flashinfer目前支持任意page size)。对于更大的page_size,我们可以使用tma减少生产者阶段的寄存器占用和指令发射压力,以及multi-casting来支持更高效的跨SM的KV-Cache访问(对num_local_heads=128的时候有用),我们将在接下来的几个版本支持。

我们保持开发路线公开透明([Roadmap] FlashInfer v0.2 to v0.3 · Issue #675 · flashinfer-ai/flashinfer[12]), 希望能和社区一起高效地迭代,共建LLM时代的基础架构。

引用链接

[1] v0.2.2版:https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.2.2
[2]eat: support deepseek prefill attention shape by yzh119 · Pull Request #765 · flashinfer-ai/flashinfer:https://github.com/flashinfer-ai/flashinfer/pull/765
[3]缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA :https://spaces.ac.cn/archives/10091
[4]DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子:https://zhuanlan.zhihu.com/p/700214123
[5]perf: dynamic split-k for MLA by yzh119 · Pull Request #863 · flashinfer-ai/flashinfer 中实现:https://github.com/flashinfer-ai/flashinfer/pull/863
[6]Refactor flashinfer logic for deepseek v3 and fix accuracy bug by Fridge003 · Pull Request #3785 · sgl-project/sglang:https://github.com/sgl-project/sglang/pull/3785
[7]DeepSeek MTP with flashinfer backend by jsun-012 · Pull Request #3775 · sgl-project/sglang:https://github.com/sgl-project/sglang/pull/3775
[8]feat: support flashinfer mla attention for deepseek v3 by zhyncs · Pull Request #3550 · sgl-project/sglang:https://github.com/sgl-project/sglang/pull/3550
[9]tsu-bin - Overview:https://github.com/tsu-bin
[10]feat: support MLA decode by tsu-bin · Pull Request #551 · flashinfer-ai/flashinfer:https://github.com/flashinfer-ai/flashinfer/pull/551
[11]FlashMLA实现:https://github.com/deepseek-ai/FlashMLA/
[12][Roadmap] FlashInfer v0.2 to v0.3 · Issue #675 · flashinfer-ai/flashinfer:https://github.com/flashinfer-ai/flashinfer/issues/675

### DeepSeek MLA 技术概述 DeepSeek MLA 是由 DeepSeek 团队开发的一套先进机器学习加速框架,旨在简化大规模语言模型的应用开发过程。该平台集成了多种前沿技术和工具,使得开发者能够更高效地构建、训练以及部署基于自然语言处理的任务。 #### 工具与库推荐 对于希望利用 DeepSeek MLA 的开发者来说,有几款重要的开源项目值得特别关注: - **Hugging Face Transformers**:这是一个非常流行的 Python 库,包含了大量预训练的语言模型及其对应的实现代码[^1]。 - **LangChain**:专为创建复杂的对话系统而设计,支持多轮次交互逻辑的设计和管理,极大地方便了聊天机器人等应用场景下的功能扩展。 - **Gradio**:允许用户通过简单的 API 调用来迅速搭建起具有图形界面的操作面板,非常适合于原型验证阶段或是向非技术人员展示成果时使用。 这些工具不仅限于理论上的指导意义,在实际操作层面也提供了详尽的帮助文档和支持材料来辅助使用者更好地理解和掌握它们的功能特性。 #### 开发者资源获取途径 为了便于广大开发者群体获得必要的技术支持和服务保障,官方团队精心准备了一系列的学习资料和技术手册放在 GitHub 上供查阅下载[^2]。这里可以找到关于如何安装配置环境的具体说明文件,还有许多实用的小贴士可以帮助提高工作效率并解决常见问题。 此外,针对不同层次的需求——无论是初学者还是资深专家——都编写了相应的培训课程内容,涵盖了从基础概念到高级特性的方方面面。特别是有关优化模型表现力方面的建议更是弥足珍贵,因为这往往决定了最终产品的竞争力所在。 #### 示例应用案例分享 下面给出一段简单示例程序片段,展示了怎样借助上述提到的一些组件快速启动一个基于 Transformer 架构的文本分类器服务端口: ```python from transformers import pipeline, AutoTokenizer, TFAutoModelForSequenceClassification import gradio as gd tokenizer = AutoTokenizer.from_pretrained('distilbert-uncased-finetuned-sst-2-english') model = TFAutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english') classifier = pipeline(task='sentiment-analysis', model=model, tokenizer=tokenizer) def predict(text_input): result = classifier(text_input)[0] label = 'Positive' if float(result['score']) >= 0.5 else 'Negative' return f'The sentiment is {label} with confidence score of {result["score"]}' iface = gd.Interface(fn=predict, inputs="text", outputs="label") iface.launch() ``` 这段脚本首先加载了一个已经过情感分析任务微调过的 DistilBERT 模型实例,并定义好预测函数 `predict()` 来接收输入字符串参数后返回相应的情感倾向判断结果。最后一步则是运用 Gradio 创建 Web UI 接口以便直观测试效果。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值