最近比较仔细的阅读了Mla相关的优化,下面从下面几个方面梳理一下该部分的一些内容:
- deekseek的模型结构以及mla部分的计算流程简介;
- MQA,GQA的kv cache优化思路以及网络结构对比;
1.1简介
DeepSeek-V3的基本架构仍然基于Transformer框架,为了实现高效推理和经济高效的训练,DeepSeek-V3还采用了MLA(多头潜在注意力)。
MHA(多头注意力)通过多个注意力头并行工作捕捉序列特征,但面临高计算成本和显存占用;MLA(多头潜在注意力)则通过低秩压缩优化键值矩阵,降低显存占用并提高推理效率。
1.2 多头注意力(MHA)
多头注意力(MHA)是Transformer模型架构中的一个核心组件,它允许模型在处理输入序列时能够同时关注来自不同位置的不同表示子空间的信息。
MHA通过将输入向量分割成多个并行的注意力“头”,每个头独立地计算注意力权重并产生输出,然后将这些输出通过拼接和线性变换进行合并以生成最终的注意力表示。
多头注意力如何进行Q,K,V计算?多头注意力(MHA)通过线性变换将张量分别转换为查询(Q),键(K)和值(V)矩阵,每个矩阵再被分割成多个头进行并行处理。
- 输入变换:输入序列首先通过三个不同的线性变换层,分别得到查询(Query)、键(key)和值(value)矩阵,这些变换通常是通过全连接层实现的。
- 分头:将查询、键和值分成多个头(即多个子空间),每个头具有不同的线性变换参数。
- 注意力计算:对于每个头,都执行一次缩放点积注意力(scaled Dot-Product Atention)运算,具体来说,计算QK的点积经过缩放,加上偏置后,使用softmax函数得到注意力权重。这些权重用于加权值矩阵,生成加权和作为每个头的输出。
- 拼接与融合:将所有头的输出拼接在一起,形成一个长向量。然后,对拼接后的向量进行一个最终的线性变换,以整合来自不同头的信息,得到最终的多头注意力输出。
MHA计算过程图:
其中 dkd_{k}dk 表示每个头的维度。
多头注意力机制和注意力机制区别是什么?多头注意力机制通过引入多个并行的注意力头,提高了模型对输入数据的全面捕捉和处理能力,使其在处理大规模数据和复杂任务时更具有优势。
- 注意力机制:通过聚焦于关键信息,提高了模型对输入数据的理解和处理能力;
- 多头注意力机制:通过并行处理和集成多个注意力头的结果,从不同的角度捕捉数据的多样性,进一步增强了模型的学习能力和表达力。
1.3 MQA与GQA
- Multi-Query Attention(MQA):所有的查询头(Query Heads)共享相同的 key 和 value。即消减了 Flops,也降低了cache,并且压缩了频繁矩阵拼接的IO耗时。
- Group-Query Attenation(GQA):是 MQA 的改进版,他通过在多个查询头之间共享 key 和value,在 MHA 和 MQA 之间找到了一种折中方案。
GQA旨在推理速度和模型质量之间取得更好的平衡,减少 MQA 带来的模型质量下降问题,同时保留比 MHA 更快的推理速度。DeepseekV1 67B,llama2 70B 和 llama3 系列全部都用了 GQA。
下面来看一下 MQA 与 GQA 带来的 kv cache压缩效果:
首先给出变量的含义:ddd 代表输入维度(input dim),nhn_hnh 代表头数(head数),dhd_hdh 代表每个头的维度,hhh_hhh代表输入的第t个向量,lll 代表transformer的层数。
对于标准的MHA而言,对于每一个token,KV cache占用的缓存的大小为 2nhdhl2n_hd_hl2nhdhl 。其中2为fp16的字节数。
对于 MQA 来说,缓存大小变为 2dhl2d_hl2dhl ,相比于 MHA 的 2nhdhl2n_hd_hl2nhdhl 而言大大减少,但是模型精度也会差一些;
对于 GQA 来说,缓存大小变为 2ngdhl2n_gd_hl2ngdhl ,其中 ngn_gng 代表head的分组数。
MQA 与 GQA 如何进行多头 KV 共享 ,或许有的时候会有这个疑问,但是当再看看模型结构的时候,心里就没有这种疑问,这是一个标准的MHA网络的attenation 网络层结构:
llama2-7b网络层结构:
model: LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-5): 6 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
)
...
)
下面是一个GQA的llama3-8b的网络层结构:
model: LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 4096)
(layers): ModuleList(
(0-5): 6 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
)
...
)
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
GQA的llama2-72b的网络层结构:
model: LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(32000, 8192, padding_idx=0)
(layers): ModuleList(
(0-5): 6 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=8192, out_features=8192, bias=False)
(k_proj): Linear(in_features=8192, out_features=1024, bias=False)
(v_proj): Linear(in_features=8192, out_features=1024, bias=False)
(o_proj): Linear(in_features=8192, out_features=8192, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
...
)
两种网络在 k_proj 和 v_proj 的输出维度是不同的,MHA 是与 q 是相同的,但是 GQA 的这部分并不是相同的,而是多对一的关系,在 llama3-8b 种有32个头,每个头的维度为4096/32=128。kv 1024/128=8组,在8B模型种32个头用8组,4个头共用一组,同理可以推算72B。对于MQA的模型,则只有一组kv,所有头共用,对应的out-feature就是头的维度。
对于deepseek v3的网络结构如下:
model: DeepseekV3ForCausalLM(
(model): DeepseekV3Model(
(embed_tokens): Embedding(129280, 7168)
(layers): ModuleList(
(0-2): 3 x DeepseekV3DecoderLayer(
(self_attn): DeepseekV3Attention(
(q_a_proj): Linear(in_features=7168, out_features=1536, bias=False)
(q_a_layernorm): DeepseekV3RMSNorm()
(q_b_proj): Linear(in_features=1536, out_features=24576, bias=False)
(kv_a_proj_with_mqa): Linear(in_features=7168, out_features=576, bias=False)
(kv_a_layernorm): DeepseekV3RMSNorm()
(kv_b_proj): Linear(in_features=512, out_features=32768, bias=False)
(o_proj): Linear(in_features=16384, out_features=7168, bias=False)
(rotary_emb): DeepseekV3YarnRotaryEmbedding()
)
...
)
)
(lm_head): Linear(in_features=7168, out_features=129280, bias=False)
)
从上面的模型结构可以看出,GQA直接通过降低输出的 out_features来直接降低kvcache。对于deepseek v3上看模型结构,似乎并没有那么明显,那么deepseek v3又是怎么去降低kv cache呢。对比如下图:
MLA借用了一个低秩分解,其中 WDKVW^{DKV}WDKV 是降维矩阵, WUKW^{UK}WUK与 WUVW^{UV}WUV是升维矩阵,其中作用原理与 GQA 很相似。也是通过改变输出通道来降低中间过程参数。这个方法其实主要是应用于降低模型的参数量,比如在CNN模型中的深度可分离卷积,这里被应用到了降低Kvcache,也是比较巧妙。
1.4 MLA详细计算流程
接着我们顺着简介中的图,顺一下mla的计算流程:
首先我们看一下mla的完整计算公式:
在deepseek中的cache就是cache蓝色的框框部分。
对图中公式的变量做如下解释说明:
- dhd_hdh:单个head的向量维度;
- nhn_hnh:是每层head的数量;
- dcd_cdc:MLA低秩压缩的维度,论文中取值:dc=4×dhd_c=4\times d_hdc=4×dh;
- ddd:隐藏层维度:d=dh×nhd=d_h \times n_hd=dh×nh;
- WDKVW^{DKV}WDKV :低秩变换矩阵;
首先来看一下 KV 的计算过程:
首先公式(41)对输入 hth_tht 做了一个低秩压缩,将 ddd 维的输入经过 WDKVW^{DKV}WDKV 变换后压缩成 dcd_cdc 维的 ctKVc_t^{KV}ctKV 。deepseek-v3 中 d=7168d=7168d=7168 ,dc=512d_c=512dc=512。然后通过公式(42)和公式(45)两个变换矩阵将KV的维度扩展变高。
再看一下 Q 的计算过程:
公式(37)(38)类似 KV 的逻辑,通过两个矩阵也做了一层低秩变换,这一步 Q 的变换看是为了减少模型的参数量。在 Deepseek-v3 中dq=1536d_q=1536dq=1536。是 KV 压缩维度 dcd_cdc 的3倍,当时相当于 d=7168 还是压缩了不少。
q, k增加 Rope 位置编码:
我们注意到在增加Rope位置编码并没有在上述计算出的 qtCq_t^CqtC,ktCk_t^CktC 的基础上乘以 Rope 的对角矩阵,而是单独计算了两个带着位置编码的 qtRq_t^RqtR 和 ktRk_t^RktR。其中需要注意的事, qtRq_t^RqtR 和 ktRk_t^RktR 的维度比较小,为单个Attenation Head维度的一半。dhR=dh/2=128/2=64d_h^R=d_h/2=128/2=64dhR=dh/2=128/2=64 。这部分计算的 ktRk_t^RktR 实际上MQA的计算方式的一种,同一层中,所有的Head共享同一个k。
然后按照公式(40)(44)跟已经计算的 qtCq_t^CqtC ktCk_t^CktC 拼接,构成完整的 qtq_tqt ktk_tkt 向量。
到目前为止,我们得到的 q,k 包括两部分拼接而成,一部分是做了低秩压缩得到的 q,k 向量,一部分是增加了Rope位置编码的 q,k 向量。后面这部分是基于MQA 方式计算得到的,所有Head共享1个k。
其实看到这里,mla 中的 QKV 计算其实分成了两步,一步是低秩压缩,另外一步是解低秩压缩,虽然降低了存储成本,但是其实感觉增加了计算,因为在真正计算 attenation 的时候,需要将 qkv 分量都先解压缩然后再进行计算,这其实并不是很高效的,怎么去节约这部分的时间呢,引入了一个概念,矩阵吸收:
例如在常规计算中: x1′=Px1{x_1}' =Px_1x1′=Px1 , x2′=Qx2{x_2}' =Qx_2x2′=Qx2 ,则有:x1′Tx2′=(Px1)T∗(Qx2)=x1TPTQx2{x_1}^{'T}x_{2}^{'}=(Px_1)^T*(Qx_2)=x_1^TP^TQx_2x1′Tx2′=(Px1)T∗(Qx2)=x1TPTQx2,而应用到attenation的计算中(假如没有引入Rope计算:)则为:
qt,iT×kj,i=(W(i)UQctQ)T×W(i)UKcjKV=(ctQ)T×(W(i)UQ)TW(i)UK×cjKVq_{t,i}^T \times k_{j,i} = \left({W}_{(i)}^{UQ} c_t^Q\right)^T \times {W}_{(i)}^{UK} c_j^{KV} = \left(c_t^Q\right)^T \times \left({W}_{(i)}^{UQ}\right)^T {W}_{(i)}^{UK} \times c_j^{KV}qt,iT×kj,i=(W(i)UQctQ)T×W(i)UKcjKV=(ctQ)T×(W(i)UQ)TW(i)UK×cjKV
不加Rope,则可以提前计算好 $\left({W}{(i)}{UQ}\right)T {W}{(i)}^{UK} $ 。这样的好处是,我们只需要存储 $ c_j^{KV}$ 这个低阶矩阵,而不需要存储 W(i)UK×cjKV{W}_{(i)}^{UK} \times c_j^{KV}W(i)UK×cjKV 这个矩阵。但是现在加上了Rope,则 (W(i)UQ)T\left({W}_{(i)}^{UQ}\right)^T(W(i)UQ)T 和 W(i)UK{W}_{(i)}^{UK}W(i)UK 之间增加了一个融合了相对位置的变量
R\mathcal{R}R ,如下述公式所示:
qt,iT×kj,i=(RtW(i)UQctQ)T×RjW(i)UKcjKV=(ctQ)T×(W(i)UQ)TRtTRjW(i)UK×cjKVq_{t,i}^T \times k_{j,i} = \left(\mathcal{R}_t W_{(i)}^{UQ} c_t^Q\right)^T \times \mathcal{R}_j W_{(i)}^{UK} c_j^{KV} = \left(c_t^Q\right)^T \times \left(W_{(i)}^{UQ}\right)^T \mathcal{R}_t^T \mathcal{R}_j W_{(i)}^{UK} \times c_j^{KV}qt,iT×kj,i=(RtW(i)UQctQ)T×RjW(i)UKcjKV=(ctQ)T×(W(i)UQ)TRtTRjW(i)UK×cjKV
而中间这个分量是随着相对位置变化而变化的,不是个固定矩阵,因此不能被提前计算好,所以论文说Rope与低秩变换不兼容。
为了引入位置编码,作者在一个很小的维度下,用MQA方式计算了q,k,也就是在每层网络中,所有head只计算一个k。引入位置编码的向量维度取得比较小为 dh/2=128/2=64d_h/2=128/2=64dh/2=128/2=64 ,所以最终q,k向量通过两部分拼接而成,计算权重时,由前后两部分分别相乘再相加得到:
qt,iT×kj,i=[qt,iC;qt,iR]T×[kj,iC;ktR]=qt,iCkj,iC+qt,iRktRq_{t,i}^{T} \times k_{j,i} = [q_{t,i}^{C}; q_{t,i}^{R}]^{T} \times [k_{j,i}^{C}; k_{t}^{R}] = q_{t,i}^{C}k_{j,i}^{C} + q_{t,i}^{R}k_{t}^{R}qt,iT×kj,i=[qt,iC;qt,iR]T×[kj,iC;ktR]=qt,iCkj,iC+qt,iRktR 。
前一项 qt,iCkj,iCq_{t,i}^{C}k_{j,i}^{C}qt,iCkj,iC 按公式6计算,通过矩阵吸收处理,缓存一个 ctKVc_t^{KV}ctKV 和一个共享的带有位置编码的 ktRk_t^RktR 。
ok,说完一大通,继续进行图解:整个更细致的带数据维度的计算流程如下图:
我们联系上文的deepseek 网络结构来看看,其中 WDQW^{DQ}WDQ 就是q_a_proj, WUQW^{UQ}WUQ 和 WQRW^{QR}WQR 一起为q_b_proj, WDKVW^{DKV}WDKV 和 WKRW^{KR}WKR 一起为kv_a_proj_with_mqa, WUKW^{UK}WUK 和 WUVW^{UV}WUV 一起为kv_b_proj。
矩阵吸收之后( WUKW^{UK}WUK 吸收进了 WUQW^{UQ}WUQ , WUVW^{UV}WUV 吸收 WOW^{O}WO )的流程图如:
其中推理阶段的 cache 没有变,位置编码计算的逻辑也没有变,但是对于部分权重的shape是有变化的,在涉及使用mla,部分的权重会进行重新计算。
参考与鸣谢:
在学习的过程中,参考了很多人的文章与博客,在这里表示十分感谢。主要有以下:
[1]https://zhuanlan.zhihu.com/p/703862723
[2]https://zhuanlan.zhihu.com/p/16730036197
[3]https://zhuanlan.zhihu.com/p/697781431
[4]https://www.bilibili.com/video/BV1jA9HYfEAC?vd_source=0d09e05dd3c1c7601dabed01875dc88e&spm_id_from=333.788.player.switch