KV-Cache技术小结(MHA,GQA,MQA,MLA)

个人博客位置: http://myhz0606.com/article/kv-cache

1 背景

KV-cache技术是目前LLMVLLM等自回归模型常用的避免冗余计算的手段。但引入该技术需要额外的存储成本。原生的kv-cache所需的存储成本与生成的token长度成正比,是目前长文本生成的主要瓶颈之一。目前针对如何降低KV-cache的存储成本激起大量研究者广泛关注。GQA (Group Query Attention),MQA (Multi Query Attention),MLA (Multi-Head Latent Attention)是目前常用的方法。本文将从经典的casual attention出发,阐述kv-cache的必要性,及目前常见优化kv-cache的手段。

TL,DR

image from deepseekv2 tech report
不同attention方法 KV cache的存储单元数量

KV-cache存储单元数量
Casual Attention K ≤ t , V ≤ t K_{\leq t},V_{\leq t} Kt,Vt 2 t d 2td 2td
MHA { K ≤ t ; i , V ≤ t ; i ∣ i = 1 , 2 , ⋯   , n h } \{K_{\leq t;i} ,V_{\leq t;i} |i=1,2,\cdots ,n_h\} {Kt;i,Vt;ii=1,2,,nh} 2 t n h d h 2 t n_hd_h 2tnhdh
GQA { K ≤ t ; g i , V ≤ t ; g i ∣ i = 1 , 2 , ⋯   , n g } \{K_{\leq t;g_i} ,V_{\leq t;g_i} |i=1,2,\cdots ,n_g\} {Kt;gi,Vt;gii=1,2,,ng} 2 t n g d h 2 t n_gd_h 2tngdh
MQA { K ≤ t ; g i , V ≤ t ; g i ∣ g i = 1 } \{K_{\leq t;g_i} ,V_{\leq t;g_i} |g_i=1\} {Kt;gi,Vt;gigi=1} 2 t d h 2 t d_h 2tdh
MLA C ≤ t k v , K ≤ t R C^{kv}_{\leq t} ,K^R_{\leq t} Ctkv,KtR t ( d c + d h R ) t (d_c + d_h^R) t(dc+dhR)

2 经典Casual-Attention的KV-Cache工作机制

假定当前层attention的输入为 X = [ x 1 ; x 2 ; ⋯   ; x T ] , X ∈ R T × d X=[x_1;x_2;\cdots ;x_T], X\in \mathbb{R} ^ {T\times d} X=[x1;x2;;xT],XRT×d, T T T为sequence的长度。通过 W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv 3个线性层得到query ,key,value。

在这里插入图片描述

Q = [ q 1 q 2 ⋮ q T ] , K = [ k 1 k 2 ⋮ k T ] , V = [ v 1 v 2 ⋮ v T ] q t , k t , v t ∈ R 1 × d (1) Q=\begin{aligned} &\begin{bmatrix}q_1 \\ q_2 \\ \vdots \\ q_T \end{bmatrix}, K=\begin{bmatrix}k_1 \\ k_2 \\ \vdots \\ k_T \end{bmatrix}, V=\begin{bmatrix}v_1 \\ v_2 \\ \vdots \\ v_T \end{bmatrix} \\ & q_t,k_t,v_t \in \mathbb{R}^{1\times d} \end{aligned} \tag{1} Q= q1q2qT K= k1k2kT V= v1v2vT qt,kt,vtR1×d(1)

随后通过标准的casual attention机制,得到输出 Y Y Y

在训练阶段,为了通过teaching forcing技巧进行并行化训练,引入一个causal mask M M M,来保证 t t t位置的token只看的到 ≤ t \leq t t的token。这个阶段没有kv-cache。

在这里插入图片描述
[ y 1 y 2 ⋮ y T ] = s o f t m a x ( [ q 1 q 2 ⋮ q T ] [ k 1 T , k 2 T , ⋯   , k T T ] d + M ) [ v 1 v 2 ⋮ v T ] (2) \begin{bmatrix}y_1 \\ y_2 \\ \vdots \\ y_T \end{bmatrix} = \mathrm{softmax} \bigg ( \frac{\begin{bmatrix}q_1 \\ q_2 \\ \vdots \\ q_T \end{bmatrix}\begin{bmatrix}k_1^T , k_2^T , \cdots , k_T^T \end{bmatrix}}{\sqrt{d}} + M\bigg ) \begin{bmatrix}v_1 \\ v_2 \\ \vdots \\ v_T \end{bmatrix} \tag{2} y1y2yT =softmax(d q1q2qT [k1T,k2T,,kTT]+M) v1v2vT (2)

**在生成阶段。**token是按序生成,在模型内部体现在 [ y 1 ; y 2 ; ⋯   ; y T ] \begin{bmatrix}y_1 ; y_2 ; \cdots ; y_T \end{bmatrix} [y1;y2;;yT]的每一行是依次输出的。

对于第一个token的生成只依赖 y 1 y_1 y1,第二个token的生成只依赖 y 2 y_2 y2,依次类推。

对每一个 y t y_t yt,attention的计算如下

在这里插入图片描述

y t = s o f t m a x ( q t [ k 1 T , ⋯   , k t T ] d ) [ v 1 ⋮ v t ] = s o f t m a x ( q t K ≤ t T d ) V ≤ t = ∑ i = 1 t s o f t m a x ( q t k i T d ) v i (3) \begin{aligned} y_t &= \mathrm{softmax} \bigg ( \frac{q_t\begin{bmatrix}k_1^T , \cdots , k_t^T \end{bmatrix}}{\sqrt{d}} \bigg ) \begin{bmatrix}v_1 \\ \vdots \\ v_t \end{bmatrix} \\ & = \mathrm{softmax} \bigg ( \frac{q_t \boxed{K^T_{\leq t}}}{\sqrt{d}} \bigg ) \boxed{V_{\leq t}} \\ & = \sum_{i=1}^{t} \mathrm{softmax} \bigg ( \frac{q_t k_i^T}{\sqrt{d}} \bigg ) v_i\\ \end{aligned} \tag{3} yt=softmax(d qt[k1T,,ktT]) v1vt =softmax(d qtKtT)Vt=i=1tsoftmax(d qtkiT)vi(3)

如果考虑位置编码,上式写改写为, P i ( ⋅ ) \mathcal{P}_i(\cdot) Pi()表示位置编码函数

y t = ∑ i = 1 t s o f t m a x ( P t ( q t ) P i ( k i T ) d ) v i (4) y_t = \sum_{i=1}^{t} \mathrm{softmax} \bigg ( \frac{\mathcal{P}_t(q_t)\mathcal{P}_i (k_i^T)}{\sqrt{d}} \bigg ) v_i \tag{4} yt=i=1tsoftmax(d Pt(qt)Pi(kiT))vi(4)

从上面的计算流程不难看出, y t y_t yt的生成只依赖当前 t t t位置的query,依赖前面所有位置 ≤ t \leq t t的key和value。为了得到 K ≤ t , K_{\leq t}, Kt, V ≤ t V_{\leq t} Vt,最naive的做法是:生成 t t t位置的token时,将 X ≤ t X_{\leq t} Xt作为Attention的输入,以此保证 K ≤ t , V ≤ t K_{\leq t},V_{\leq t} Kt,Vt能够被正确计算。Naive的做法没有kv-cache

但从上面的计算流程我们不难看出, y t y_t yt需要的 K ≤ t , V ≤ t K_{\leq t},V_{\leq t} Kt,Vt K ≤ t − 1 , V ≤ t − 1 K_{\leq t -1},V_{\leq t-1} Kt1,Vt1已经在 y t − 1 y_{t-1} yt1的计算中被计算。因此可以能把 y t − 1 y_{t-1} yt1算好的 K ≤ t − 1 , V ≤ t − 1 K_{\leq t -1},V_{\leq t-1} Kt1,Vt1保存起来,在 t t t位置只需计算 k t , v t k_t,v_t kt,vt,再与前面的 K ≤ t − 1 , V ≤ t − 1 K_{\leq t -1},V_{\leq t-1} Kt1,Vt1进行拼接就可以得到 K ≤ t , V ≤ t K_{\leq t},V_{\leq t} Kt,Vt。这样大大减少了冗余的计算量。这就是kv-cache的核心motivation。(公式中被“框起来”的部分是可以cache的。)

用kv-cache的生成思路,

生成第 1 1 1个token时,此时attention层输入 x 1 x_1 x1,输出 y 1 , ( k 1 , v 1 ) y_1, (k_1, v_1) y1,(k1,v1) ( k 1 , v 1 ) (k_1,v_1) (k1,v1)是缓存的kv-cache
生成第 2 2 2个token时,此时attention层输入 x 2 , ( k 1 , v 1 ) x_2,(k_1, v_1) x2,(k1,v1),输出 y 2 , ( K ≤ 2 , V ≤ 2 ) y_2, (K_{\leq 2}, V_{\leq2}) y2,(K2,V2)

生成第 t t t个token时,此时attention层输入 x t , ( K ≤ t − 1 , V ≤ t − 1 ) x_t,(K_{\leq t-1}, V_{\leq t-1}) xt,(Kt1,Vt1),输出 y t , ( K ≤ t , V ≤ t ) y_t, (K_{\leq t}, V_{\leq t}) yt,(Kt,Vt)

kv-cache能够显著降低attention的计算量,但随着生成token的增多,kv-cache所需的存储成本呈线性增加,导致GPU的显存成为生成长度的瓶颈。

3 Multi-Head Attention(MHA) KV-Cache工作机制

paper: https://arxiv.org/pdf/1706.03762

MHA是上面的一个推广。假定MHA的输入为 X = [ x 1 ; x 2 ; ⋯   ; x T ] , X ∈ R T × d X=[x_1;x_2;\cdots ;x_T], X\in \mathbb{R} ^ {T\times d} X=[x1;x2;;xT],XRT×d, T T T为sequence的长度。假定有 n h n_h nh个head,每个head投影的维度为 d h = d n h d_h=\frac{d}{n_h} dh=nhd
在这里插入图片描述

通过线性层的矩阵计算,得到不同head下的 Q i , K i , V i , i ∈ { 1 , ⋯   , n h } Q_i,K_i,V_i,i\in \{1,\cdots,n_h\} Qi,Ki,Vi,i{1,,nh}

q 1 ; i q_{1;i} q1;i表示head i i i query矩阵在序列位置为 1 1 1处的向量,其他符号记法类似。

在生成阶段每一个head经过attention计算后的 t t t位置的输出 y t ; i y_{t;i} yt;i如下( y 1 ; i , y 2 ; i , ⋯ y T ; i y_{1;i},y_{2;i}, \cdots y_{T;i} y1;iy2;iyT;i依序生成),

在这里插入图片描述

y t ; i = s o f t m a x ( q t ; i [ k 1 ; i T , ⋯   , k t ; i T ] d ) [ v 1 ; i ⋮ v t ; i ] = s o f t m a x ( q t ; i K ≤ t ; i T d ) V ≤ t ; i (5) \begin{aligned} y_{t;i} &= \mathrm{softmax} \bigg ( \frac{q_{t;i}\begin{bmatrix}k_{1;i}^T , \cdots , k_{t;i}^T \end{bmatrix}}{\sqrt{d}} \bigg ) \begin{bmatrix}v_{1;i} \\ \vdots \\ v_{t;i} \end{bmatrix} \\ & = \mathrm{softmax} \bigg ( \frac{q_{t;i} \boxed{K_{\leq t;i}^T} }{\sqrt{d}} \bigg ) \boxed{V_{\leq t;i}} \\ \end{aligned} \tag{5} yt;i=softmax(d qt;i[k1;iT,,kt;iT]) v1;ivt;i =softmax(d qt;iKt;iT)Vt;i(5)

循环Attention head i = 1 , ⋯   , n h i=1,\cdots,n_h i=1,,nh ,可以计算所有head t t t时刻的输出 { y t ; 1 , ⋯   , y t ; n h } \{y_{t;1}, \cdots ,y_{t;n_h}\} {yt;1,,yt;nh}最后将不同head的输出 y t ; i y_{t;i} yt;i进行拼接,得到最终的输出

y t = C a t ( [ y t ; 1 , y t ; 2 , ⋯   , y t ; n h ] , d i m = 1 ) y t ∈ R 1 × d (6) y_{t} = \mathrm{Cat}([y_{t;1}, y_{t;2}, \cdots ,y_{t;n_h}], \mathrm{dim}=1) \\y_t \in \mathbb{R}^{1 \times d} \tag{6} yt=Cat([yt;1,yt;2,,yt;nh],dim=1)ytR1×d(6)

对于MHA而言, y t y_t yt的计算缓存的kv-cache为 { K ≤ t ; i , V ≤ t ; i ∣ i = 1 , 2 , ⋯   , n h } \{K_{\leq t;i} ,V_{\leq t;i} |i=1,2,\cdots ,n_h\} {Kt;i,Vt;ii=1,2,,nh}

4 Group Query Attention(GQA) KV-Cache工作机制

paper: https://arxiv.org/pdf/2305.13245

GQA的attention计算机制与MHA一致。有所区别的是,GQA为了降低KV-Cache的存储,将attention的head分为了 n g n_g ng组,同一组共享kv-cache
在这里插入图片描述

g i = ⌈ ( i n g ) ⌉ (7) g_i = \lceil (\frac{i}{n_g} ) \rceil\tag{7} gi=⌈(ngi)⌉(7)

⌈ ⋅ ⌉ \lceil \cdot \rceil 是向上取整符号。若 n g = 4 n_g = 4 ng=4,那么 i ∈ { 1 , 2 , 3 , 4 } i \in \{ 1,2,3,4 \} i{1,2,3,4}共享 g i = 1 g_i = 1 gi=1这个group的key,value。

同样,在生成阶段 y 1 ; i , y 2 ; i , ⋯ y T ; i y_{1;i},y_{2;i}, \cdots y_{T;i} y1;iy2;iyT;i依序生成。每一个head经过attention计算后的 t t t位置的输出 y t ; i y_{t;i} yt;i如下,

y t ; i = s o f t m a x ( q t ; i [ k 1 ; g i T , ⋯   , k t ; g i T ] d ) [ v 1 ; g i ⋮ v t ; g i ] = s o f t m a x ( q t ; i K ≤ t ; g i T d ) V ≤ t ; g i (8) \begin{aligned} y_{t;i} &= \mathrm{softmax} \bigg ( \frac{q_{t;i}\begin{bmatrix}k_{1;g_i}^T , \cdots , k_{t;g_i}^T \end{bmatrix}}{\sqrt{d}} \bigg ) \begin{bmatrix}v_{1;g_i} \\ \vdots \\ v_{t;g_i} \end{bmatrix} \\ & = \mathrm{softmax} \bigg ( \frac{q_{t;i} \boxed{K_{\leq t;g_i}^T} }{\sqrt{d}} \bigg ) \boxed{ V_{\leq t;g_i} } \\ \end{aligned} \tag{8} yt;i=softmax(d qt;i[k1;giT,,kt;giT]) v1;givt;gi =softmax(d qt;iKt;giT)Vt;gi(8)

Loop Attention head i = 1 , ⋯   , n h i=1,\cdots,n_h i=1,,nh ,可以计算所有head t t t时刻的输出 { y t ; 1 , ⋯   , y t ; n h } \{y_{t;1}, \cdots ,y_{t;n_h}\} {yt;1,,yt;nh}最后将不同head的输出 y t ; i y_{t;i} yt;i进行拼接,得到最终的输出

y t = C a t ( [ y t ; 1 , y t ; 2 , ⋯   , y t ; n h ] , d i m = 1 ) y t ∈ R 1 × d (9) y_{t} = \mathrm{Cat}([y_{t;1}, y_{t;2}, \cdots ,y_{t;n_h}], \mathrm{dim}=1) \\y_t \in \mathbb{R}^{1 \times d} \tag{9} yt=Cat([yt;1,yt;2,,yt;nh],dim=1)ytR1×d(9)

对于GQA而言, y t y_t yt的计算缓存的kv-cache为 { K ≤ t ; g i , V ≤ t ; g i ∣ i = 1 , 2 , ⋯   , n g } \{K_{\leq t;g_i} ,V_{\leq t;g_i} |i=1,2,\cdots ,n_g\} {Kt;gi,Vt;gii=1,2,,ng},相比标准的MHA,KV-cache降低了 n h n g \frac{n_h}{n_g} ngnh

5 Multi Query Attention(MQA) KV-Cache工作机制

paper: https://arxiv.org/pdf/1911.02150

MQAGQA的一个特例。当 n g = 1 n_g=1 ng=1时,即所有head的query共享同一组key, value,此时的GQA成为MQA

对于MQA而言, y t y_t yt的计算缓存的kv-cache为 { K ≤ t ; g i , V ≤ t ; g i ∣ g i = 1 } \{K_{\leq t;g_i} ,V_{\leq t;g_i} |g_i=1\} {Kt;gi,Vt;gigi=1},相比标准的MHA,KV-cache降低了 n h n_h nh倍。

6 Multi-Head Latent Attention (MLA)的工作机制

MLAdeepseek提出的一项针对kv-cache的优化。

paper: https://arxiv.org/abs/2405.04434

(一)先抛开位置编码

假定MLA的输入为 X = [ x 1 ; x 2 ; ⋯   ; x T ] , X ∈ R T × d X=[x_1;x_2;\cdots ;x_T], X\in \mathbb{R} ^ {T\times d} X=[x1;x2;;xT],XRT×d, T T T为sequence的长度。假定有 n h n_h nh个head,每个head投影的维度为 d n h \frac{d}{n_h} nhd。初步来看未引入位置编码的MLA像是引入了一个低秩分解矩阵(类似LoRA的做法)的MHA

在这里插入图片描述

head i i i的Q,K,V计算过程如下

[ Q i K i V i ] = [ X W q ; i X W k ; i X W v ; i ] ⏟ M H A → [ Q i K i V i ] = [ X W q / A W q / B ; i X W k v / A W k / B ; i X W k v / A W v / B ; i ] = [ C q W q / B ; i C k v W k / B ; i C k v W v / B ; i ] ⏟ M L A (10) \underbrace{ \begin{bmatrix} Q_i \\K_i\\V_i \end{bmatrix} =\begin{bmatrix} XW_{q;i} \\ XW_{k;i}\\ XW_{v;i} \end{bmatrix} }_{MHA} \rightarrow \underbrace{ \begin{bmatrix} Q_i \\K_i\\V_i \end{bmatrix}= \begin{bmatrix} XW_{q/A} W_{q/B;i}\\ XW_{kv/A}W_{k/B;i}\\ XW_{kv/A}W_{v/B;i} \end{bmatrix}= \begin{bmatrix} C^{q}W_{q/B;i}\\ C^{kv}W_{k/B;i}\\ C^{kv}W_{v/B;i} \end{bmatrix} }_{MLA} \tag{10} MHA QiKiVi = XWq;iXWk;iXWv;i MLA QiKiVi = XWq/AWq/B;iXWkv/AWk/B;iXWkv/AWv/B;i = CqWq/B;iCkvWk/B;iCkvWv/B;i (10)

矩阵维度变换说明

dimension
W k v / A W_{kv/A} Wkv/A (下标kv表示key-value的compress latent编码投影矩阵,/A类比LORA的A矩阵) R d × d c \mathbb{R} ^{d \times d_c} Rd×dc
W k / B ; i , W v / B ; i W_{k/B;i},W_{v/B;i} Wk/B;i,Wv/B;i/B 类比LORA的B矩阵, $;i$)表示attention第 i i i个head R d c × d n h \mathbb{R} ^{d_c \times \frac{d}{n_h}} Rdc×nhd
W q / A ; i W_{q/A;i} Wq/A;i R d × d c ′ \mathbb{R} ^{ d \times d'_c} Rd×dc
W q / B ; i W_{q/B;i} Wq/B;i R d n h × d c \mathbb{R} ^{\frac{d}{n_h} \times d_c} Rnhd×dc
C q C^q Cq (上标q表示 query的compress latent code) R T × d c ′ \mathbb{R} ^ {T\times d_c'} RT×dc
C k v C^{kv} Ckv (表示 key-value的compress latent code) R T × d c \mathbb{R} ^ {T\times d_c} RT×dc

在生成阶段,每一个head经过attention计算后的 t t t位置的输出 y t ; i y_{t;i} yt;i如下( y 1 ; i , y 2 ; i , ⋯ y T ; i y_{1;i},y_{2;i}, \cdots y_{T;i} y1;iy2;iyT;i依序生成)

y t ; i =   s o f t m a x ( q t ; i K ≤ t ; i T d ) V ≤ t ; i = s o f t m a x ( q t ; i W k / B ; i T ( C ≤ t k v ) T d ) C ≤ t k v W v / B ; i \begin{align*} y_{t;i} &= \ \mathrm{softmax} \bigg ( \frac{q_{t;i} K^T_{\leq t;i}}{\sqrt{d}} \bigg ) V_{\leq t;i} \tag{9} \\ & = \mathrm{softmax} \bigg ( \frac{q_{t;i} W_{k/B;i}^{T}(\boxed{C^{kv}_{\leq t}})^T}{\sqrt{d}} \bigg ) \boxed{C^{kv}_{\leq t}}W_{v/B;i}\tag{10} \end{align*} yt;i= softmax(d qt;iKt;iT)Vt;i=softmax(d qt;iWk/B;iT(Ctkv)T)CtkvWv/B;i(9)(10)

式(9)和MHA的generate阶段的形式相同,当然可以通过缓存 { K ≤ t ; i , V ≤ t ; i ∣ i = 1 , 2 , ⋯   , n h } \{K_{\leq t;i} ,V_{\leq t;i} |i=1,2,\cdots ,n_h\} {Kt;i,Vt;ii=1,2,,nh}实现kv-cache。

但MLA提供了一个新的方法(式10),只需要缓存 C ≤ t k v ∈ R t × d c C^{kv}_{\leq t} \in \mathbb{R}^{t \times d_c} CtkvRt×dc即可,相比原始方法的kv-cache的存储单元数量从 t × d n h × n h × 2 = 2 d t t\times \frac{d}{n_h} \times n_h \times 2 = 2dt t×nhd×nh×2=2dt降低到 t × d c t\times d_c t×dc但这个方法需要引入两个矩阵乘法的计算量。因为 d c d_c dc不大,引入的计算量是可以接受的。(还有一种更为巧妙的方法能规避计算量增加的问题,在(二)中介绍)

(二)引入位置编码的MLA

这个形式也是deepseek论文中提出的形式。有了上面的基础,再理解就很简单了。同样假定MLA的输入为 X = [ x 1 ; x 2 ; ⋯   ; x T ] , X ∈ R T × d X=[x_1;x_2;\cdots ;x_T], X\in \mathbb{R} ^ {T\times d} X=[x1;x2;;xT],XRT×d, T T T为sequence的长度。假定有 n h n_h nh个head,每个head投影的维度为 d h = d n h d_h=\frac{d}{n_h} dh=nhd

在这里插入图片描述

head i i i的Q,K,V计算过程如下

[ Q i K i V i ] ⇒ [ X W q ; i X W k ; i X W v ; i ] ⏟ M H A ⇒ [ X W q / A W q / B ; i X W k v / A W k / B ; i X W k v / A W v / B ; i ] = [ C q W q / B ; i C k v W k / B ; i C k v W v / B ; i ] ⏟ MLA, no pos embedding ⇒ [ Q i C , Q i R K i C , K R V i ] = [ C q W q C / B ; i , R O P E ( C q W q R / B ; i ) C k v W k C / B ; i , R O P E ( X W k R ) C k v W v / B ; i ] ⏟ MLA with pos embedding (11) \begin{aligned} \begin{bmatrix} Q_i \\K_i\\V_i \end{bmatrix} & \Rightarrow \underbrace{\begin{bmatrix} XW_{q;i} \\ XW_{k;i}\\ XW_{v;i} \end{bmatrix} }_{MHA} \Rightarrow \underbrace{ \begin{bmatrix} XW_{q/A} W_{q/B;i}\\ XW_{kv/A}W_{k/B;i}\\ XW_{kv/A}W_{v/B;i} \end{bmatrix} = \begin{bmatrix} C^{q}W_{q/B;i}\\ C^{kv}W_{k/B;i}\\ C^{kv}W_{v/B;i} \end{bmatrix} }_{\text{MLA, no pos embedding}} \\ & \Rightarrow \underbrace{ \begin{bmatrix} Q_i^C,Q_i^R \\K_i^C, K^R\\V_i \end{bmatrix} = \begin{bmatrix} C^qW_{q^C/B;i},\mathrm{ROPE} (C^qW_{q^R/B;i}) \\ C^{kv}W_{k^C/B;i}, \mathrm{ROPE} (XW_{k^R})\\ C^{kv}W_{v/B;i} \end{bmatrix} }_{\text{MLA with pos embedding}} \end{aligned} \tag{11} QiKiVi MHA XWq;iXWk;iXWv;i MLA, no pos embedding XWq/AWq/B;iXWkv/AWk/B;iXWkv/AWv/B;i = CqWq/B;iCkvWk/B;iCkvWv/B;i MLA with pos embedding QiC,QiRKiC,KRVi = CqWqC/B;i,ROPE(CqWqR/B;i)CkvWkC/B;i,ROPE(XWkR)CkvWv/B;i (11)

矩阵维度变换说明

dimension
W k v / A W_{kv/A} Wkv/A(下标kv表示key-value的compress latent编码投影矩阵,/A类比LORA的A矩阵) R d × d c \mathbb{R} ^{d \times d_c} Rd×dc
W k C / B ; i , W v / B ; i W_{k^C/B;i},W_{v/B;i} WkC/B;i,Wv/B;i (/B 类比LORA的B矩阵, $;i$)表示attention第 i i i个head R d c × d h \mathbb{R} ^{d_c \times d_h} Rdc×dh
W k R W_{k^R} WkR R d × d h R \mathbb{R} ^{d \times d_h^R} Rd×dhR
W q / A ; i W_{q/A;i} Wq/A;i R d × d c ′ \mathbb{R} ^{d \times d'_c} Rd×dc
W q C / B ; i W_{q^C/B;i} WqC/B;i R d c ′ × d h \mathbb{R} ^{d'_c \times d_h} Rdc×dh
W q R / B ; i W_{q^R/B;i} WqR/B;i R d c ′ × d h R \mathbb{R} ^{d_c' \times d^R_h } Rdc×dhR
C q C^q Cq (上标q表示 query的compress latent code) R T × d c ′ \mathbb{R} ^ {T\times d_c'} RT×dc
C k v C^{kv} Ckv(表示 key-value的compress latent code) R T × d c \mathbb{R} ^ {T\times d_c} RT×dc
Q i C Q_{i}^{C} QiC (上标C表示compression的首字母“C”) R T × d h \mathbb{R}^{T\times d_h} RT×dh (不含位置编码的query)
Q i R Q_{i}^{R} QiR (上标R是RoPE的R R T × d h R \mathbb{R}^{T\times d_h^R} RT×dhR (包含位置编码的query)

在这里插入图片描述

y t ; i = s o f t m a x ( [ q t ; i C , q t ; i R ] [ ( k 1 ; i C ) T , ⋯   , ( k t ; i C ) T ( k 1 R ) T , ⋯   , ( k t R ) T ] d h + d h R ) V ≤ t ; i = s o f t m a x ( [ q t ; i C , q t ; i R ] [ ( K ≤ t ; i C ) T ( K ≤ t R ) T ] d h + d h R ) V ≤ t ; i = s o f t m a x ( [ q t ; i C , q t ; i R ] [ ( C ≤ t k v W k C / B ; i ) T ( K ≤ t R ) T ] d h + d h R ) C ≤ t k v W v / B ; i \begin{align*} y_{t;i} &= \mathrm{softmax} \bigg ( \frac{ \begin{bmatrix} q^C_{t;i}, q^R_{t;i} \end{bmatrix} \begin{bmatrix} (k^C_{1;_i})^T , \cdots , (k^C_{t;i})^T \\ (k^R_{1})^T , \cdots , (k^R_{t})^T \end{bmatrix} }{\sqrt{d_h+d_h^R}} \bigg ) V_{\leq t;i} \tag{12} \\ &= \mathrm{softmax} \bigg ( \frac{ \begin{bmatrix} q^C_{t;i}, q^R_{t;i} \end{bmatrix} \begin{bmatrix} (K^C_{\leq t;i})^T \\ (K^R_{\leq t})^T \end{bmatrix} }{\sqrt{d_h+d_h^R}} \bigg ) V_{\leq t;i} \tag{13} \\ & = \mathrm{softmax} \bigg ( \frac{ \begin{bmatrix} q^C_{t;i}, q^R_{t;i} \end{bmatrix} \begin{bmatrix} ( \boxed{ C^{kv}_{\le t}} W_{k^C/B;i})^T \\ (\boxed{ K^R_{\leq t}})^T \end{bmatrix} }{\sqrt{d_h+d_h^R}} \bigg ) \boxed {C^{kv}_{\leq t} } W_{v/B;i}\tag{14} \end{align*} yt;i=softmax(dh+dhR [qt;iC,qt;iR][(k1;iC)T,,(kt;iC)T(k1R)T,,(ktR)T])Vt;i=softmax(dh+dhR [qt;iC,qt;iR][(Kt;iC)T(KtR)T])Vt;i=softmax(dh+dhR [qt;iC,qt;iR] (CtkvWkC/B;i)T(KtR)T )CtkvWv/B;i(12)(13)(14)

从式14可见,加了位置编码的MLA相比无位置编码的情形多缓存了一个 K ≤ t R K^R_{\leq t} KtR。这里需要注意,对于所有attention head K ≤ t R K^R_{\leq t} KtR是共享的(类似MQA)。

此时KV-cache需要缓存 { C ≤ t k v , K ≤ t R } \{C^{kv}_{\leq t}, K^R_{\leq t}\} {Ctkv,KtR} 的存储单元数量为 t ( d c + d h R ) t (d_c + d_h^R) t(dc+dhR)

**与不加位置编码的情形一致,这个方法推理时需要引入两个矩阵乘法的计算量, C ≤ t k v W k C / B ; i C^{kv}_{\le t} W_{k^C/B;i} CtkvWkC/B;i C ≤ t k v W v / B ; i C^{kv}_{\leq t} W_{v/B;i} CtkvWv/B;i。**不妨对式(14)再次进行变形,可以看到MLA 又一巧妙的设计

y t ; i = s o f t m a x ( [ q t ; i C , q t ; i R ] [ ( C ≤ t k v W k C / B ; i ) T ( K ≤ t R ) T ] d h + d h R ) C ≤ t k v W v / B ; i = s o f t m a x ( q t ; i C ( W k C / B ; i ) T ( C ≤ t k v ) T + q t ; i R ( K ≤ t R ) T d h + d h R ) C ≤ t k v W v / B ; i \begin{align*} y_{t;i} &= \mathrm{softmax} \bigg ( \frac{ \begin{bmatrix} q^C_{t;i}, q^R_{t;i} \end{bmatrix} \begin{bmatrix} ( \boxed{ C^{kv}_{\le t}} W_{k^C/B;i})^T \\ (\boxed{ K^R_{\leq t}})^T \end{bmatrix} }{\sqrt{d_h+d_h^R}} \bigg ) \boxed {C^{kv}_{\leq t} } W_{v/B;i}\\ &= \mathrm{softmax} \bigg ( \frac{ q^C_{t;i}(W_{k^C/B;i})^T(\boxed{ C^{kv}_{\le t}})^T+ q^R_{t;i} ( \boxed{K^R_{\leq t}})^T }{\sqrt{d_h+d_h^R}} \bigg ) \boxed {C^{kv}_{\leq t} } W_{v/B;i} \tag{15}\\ \end{align*} yt;i=softmax(dh+dhR [qt;iC,qt;iR] (CtkvWkC/B;i)T(KtR)T )CtkvWv/B;i=softmax(dh+dhR qt;iC(WkC/B;i)T(Ctkv)T+qt;iR(KtR)T)CtkvWv/B;i(15)

从式(15)可见,在推理时, W q C / B ; i ( W k C / B ; i ) T W_{q^C/B;i}(W_{k^C/B;i})^T WqC/B;i(WkC/B;i)T可以预先合并为1个矩阵,同理 W v / B ; i W_{v/B;i} Wv/B;i 也可以和随后的线性层的权重进行合并。但计算量的降低主要与矩阵乘法的计算顺序改变导致:

计算次序element-level乘法执行次数
q t ; i C ⏟ 2 ◯ ( C ≤ t k v W k C / B ; i ) T ⏟ 1 ◯ \underbrace{q^C_{t;i}}_{\textcircled{2}}\underbrace{( C^{kv}_{\le t} W_{k^C/B;i})^T}_{\textcircled{1}} 2 qt;iC1 (CtkvWkC/B;i)T t × d c × d h ⏟ 1 ◯ + t × d h × 1 ⏟ 2 ◯ = ( d c + 1 ) t d h \underbrace{t \times d_c \times d_h}_{\textcircled{1}} + \underbrace{t \times d_h \times 1}_{\textcircled{2}} = (d_c + 1)td_h 1 t×dc×dh+2 t×dh×1=(dc+1)tdh
q t ; i C ( W k C / B ; i ) T ⏟ 1 ◯ ( C ≤ t k v ) T ⏟ 2 ◯ \underbrace{q^C_{t;i}(W_{k^C/B;i})^T}_{\textcircled{1}}\underbrace{( C^{kv}_{\le t})^T }_{\textcircled{2}} 1 qt;iC(WkC/B;i)T2 (Ctkv)T 1 × d h × d c ⏟ 1 ◯ + 1 × d c × t ⏟ 2 ◯ = ( d h + t ) d c \underbrace{1 \times d_h \times d_c}_{\textcircled{1}} + \underbrace{1 \times d_c \times t}_{\textcircled{2}} = (d_h + t)d_c 1 1×dh×dc+2 1×dc×t=(dh+t)dc

维度说明: q t ; i C ∈ R 1 × d h , C ≤ t k v ∈ R t × d c , W k C / B ; i ∈ R d c × d h q^C_{t;i} \in \mathbb{R}^{1 \times d_h}, C^{kv}_{\leq t} \in \mathbb{R}^ {t \times d_c}, W_{k^C/B;i} \in \mathbb{R}^{d_c \times d_h} qt;iCR1×dh,CtkvRt×dc,WkC/B;iRdc×dh

7 小结

文本详细介绍了kv-cache原理,及降低kv-cache存储成本目前常用的MQA,GQA,MLA方法。如有疏漏之处,敬请指出。

image from deepseekv2 tech report
不同attention方法 KV cache的存储单元数量

KV-cache存储单元数量
Casual Attention K ≤ t , V ≤ t K_{\leq t},V_{\leq t} Kt,Vt 2 t d 2td 2td
MHA { K ≤ t ; i , V ≤ t ; i ∣ i = 1 , 2 , ⋯   , n h } \{K_{\leq t;i} ,V_{\leq t;i} |i=1,2,\cdots ,n_h\} {Kt;i,Vt;ii=1,2,,nh} 2 t n h d h 2 t n_hd_h 2tnhdh
GQA { K ≤ t ; g i , V ≤ t ; g i ∣ i = 1 , 2 , ⋯   , n g } \{K_{\leq t;g_i} ,V_{\leq t;g_i} |i=1,2,\cdots ,n_g\} {Kt;gi,Vt;gii=1,2,,ng} 2 t n g d h 2 t n_gd_h 2tngdh
MQA { K ≤ t ; g i , V ≤ t ; g i ∣ g i = 1 } \{K_{\leq t;g_i} ,V_{\leq t;g_i} |g_i=1\} {Kt;gi,Vt;gigi=1} 2 t d h 2 t d_h 2tdh
MLA C ≤ t k v , K ≤ t R C^{kv}_{\leq t} ,K^R_{\leq t} Ctkv,KtR t ( d c + d h R ) t (d_c + d_h^R) t(dc+dhR)
### MHAGQAMLA 的区别及应用场合 #### 多头注意力机制(Multi-Head Attention, MHA) 多头注意力机制允许模型在不同的表示子空间中并行关注不同位置的信息。每个头独立操作,最终结果通过拼接各头的结果来获得更丰富的特征表达[^1]。 ```python import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, embed_size, num_heads): super(MultiHeadAttention, self).__init__() self.embed_size = embed_size self.num_heads = num_heads # 定义线性变换层 self.query_linear = nn.Linear(embed_size, embed_size) self.key_linear = nn.Linear(embed_size, embed_size) self.value_linear = nn.Linear(embed_size, embed_size) def forward(self, query, key, value): batch_size = query.size(0) # 对输入进行线性变换 Q = self.query_linear(query) K = self.key_linear(key) V = self.value_linear(value) # 将嵌入维度分割成多个头 Q = Q.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) K = K.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) V = V.view(batch_size, -1, self.num_heads, self.embed_size // self.num_heads).transpose(1, 2) # 计算注意力分数并加权求和 scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.embed_size ** 0.5) attention_weights = F.softmax(scores, dim=-1) output = torch.matmul(attention_weights, V) # 合并头部并将结果传递给下一个线性层 output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_size) return output ``` #### 组化查询注意力机制(Grouped Query Attention, GQA) 为了减少计算量,GQA引入了查询分组的概念,即某些查询可以共享相同的键和值矩阵。这减少了重复计算的数量,在大规模数据集上尤其有效[^3]。 ```python class GroupedQueryAttention(nn.Module): def __init__(self, embed_size, num_groups, heads_per_group): super(GroupedQueryAttention, self).__init__() self.embed_size = embed_size self.num_groups = num_groups self.heads_per_group = heads_per_group # 初始化参数... def forward(self, queries, keys, values): # 实现GQA的具体逻辑... pass ``` #### 压缩键值注意力机制(Compressed Key/Value Attention, MLAMLA进一步优化了资源利用效率,通过对键和值向量实施低秩近似压缩处理,从而显著降低了存储开销以及前向传播过程中的运算复杂度。 $$K_{\text{compressed}} = U_K \cdot S_K \cdot V_K^T$$ $$V_{\text{compressed}} = U_V \cdot S_V \cdot V_V^T$$ ```python from scipy.linalg import svd def compress_matrix(matrix, rank): u, s, vh = svd(matrix) compressed = np.dot(u[:, :rank], np.dot(np.diag(s[:rank]), vh[:rank, :])) return compressed class CompressedKeyValueAttention(nn.Module): def __init__(self, embed_size, compression_rank): super(CompressedKeyValueAttention, self).__init__() self.compress_key = lambda k: compress_matrix(k, compression_rank) self.compress_value = lambda v: compress_matrix(v, compression_rank) # 其他初始化... def forward(self, queries, keys, values): compressed_keys = self.compress_key(keys) compressed_values = self.compress_value(values) # 使用压缩后的keys和values继续执行标准的注意力机制... pass ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值