FlashInfer中DeepSeek MLA的内核设计

作者:叶子豪
原文:https://zhuanlan.zhihu.com/p/25920092499
 


在过去的两个月中,FlashInfer团队在密集地迭代优化DeepSeek MLA性能,并支持了DeepSeek MLA在prefill/decode/extend各个阶段的attention算子(支持SM_80/SM_89/SM_90,和任意的page大小),并在v0.2.2版[1]本中集中发布。在这篇文章中,我们将介绍近期我们一些版本中所做的一些优化,以及背后的思路。抛砖引玉,希望能启发到大家。

Prefill阶段的MLA计算与普通的MHA计算大致相同,唯一的区别在于需要支持q/k和v/o使用不同的head_dim,只需要简单修改attention模版即可(见 feat: support deepseek prefill attention shape by yzh119 · Pull Request #765 · flashinfer-ai/flashinfer[2])。而Decode和Extend阶段,我们需要单独设计高效的融合Page Attention算子,以便高效地与低秩kv-cache进行attention计算,相对于其它attention变体而言,MLA的挑战主要来自于以下两点:

  • • 矩阵吸收之后较大的head dimension (512)

  • • 矩阵吸收之后key和value共用同一个矩阵,流水线需要重新设计

我们将介绍flashinfer团队在这些问题上的解决方案(我们相信还有很多更好的设计),希望能引起大家的讨论和思考。

背景——MLA

对于MLA模型本身,已经有了很多的解读,推荐阅读苏神的文章了解背景: 缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA [3]

背景——矩阵吸收

在推理阶段的注意力计算过程中,为了避免对 升维带来的显存带宽开销,我们可以改变矩阵的结合顺序(以下的推导省略了多头):,,以减少注意力算子的输入矩阵大小(I/O决定了解码阶段的效率),可以在注意力计算之前算好(在解码的过程中相对较小),这样整个Attention的计算转换为(同样,这里忽略了“头”的维度,其中和 均为多头矩阵)):

这个技巧叫做矩阵吸收,详细推导可以参考章老师的文章:DeepSeek-V2 高性能推理 (1):通过矩阵吸收十倍提速 MLA 算子[4]。通过矩阵吸收变换,我们可以将Multi-Head-Attention(MHA)计算转换成Multi-Query-Attention(MQA)计算,key和value共享一个 c_{kv} 矩阵,也就是MLA中的KV-Cache,其中没有“头”这个维度,query的部分有多头。

经过矩阵吸收变换之后,MLA的计算流程如下:

MLA计算流程(https://github.com/flashinfer-ai/flashinfer/pull/551),方框中是FlashInfer需要实现的MLA算子,其中c_kv被同时用做K和V,head_dim为512,rope部分的head_dim为64

MLA计算流程(https://github.com/flashinfer-ai/flashinfer/pull/551),方框中是FlashInfer需要实现的MLA算子,其中c_kv被同时用做K和V,head_dim为512,rope部分的head_dim为64

可以用代码描述为:

q_pe = W_QR(c_q)
q_nope = W_UQ_UK(c_q)
output = W_UV_O(MQA(q_pe, q_nope, c_kv, k_pe))

因此我们的目标是设计一个MQA算子,qk计算部分需要加上rope的维度(这块只需要简单地修改算子即可),重用key-cache和value-cache矩阵,同时需要支持较大的head_dim(512)。

是否需要融合算子?

在实现这个算子之前,我们先思考一个问题,我们是否需要这样一个融合算子:FlashAttention最初设计的初衷是减少对softmax矩阵储存以及方寸的开销,其大小正比于  ,占整体I/O的比值为:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值