个人博客位置: http://myhz0606.com/article/kv-cache
1 背景
KV-cache技术是目前LLM
,VLLM
等自回归模型常用的避免冗余计算的手段。但引入该技术需要额外的存储成本。原生的kv-cache所需的存储成本与生成的token长度成正比,是目前长文本生成的主要瓶颈之一。目前针对如何降低KV-cache的存储成本激起大量研究者广泛关注。GQA
(Group Query Attention),MQA
(Multi Query Attention),MLA
(Multi-Head Latent Attention)是目前常用的方法。本文将从经典的casual attention出发,阐述kv-cache的必要性,及目前常见优化kv-cache的手段。
TL,DR
不同attention方法 KV cache的存储单元数量
KV-cache | 存储单元数量 | |
---|---|---|
Casual Attention | K ≤ t , V ≤ t K_{\leq t},V_{\leq t} K≤t,V≤t | 2 t d 2td 2td |
MHA | { K ≤ t ; i , V ≤ t ; i ∣ i = 1 , 2 , ⋯ , n h } \{K_{\leq t;i} ,V_{\leq t;i} |i=1,2,\cdots ,n_h\} {K≤t;i,V≤t;i∣i=1,2,⋯,nh} | 2 t n h d h 2 t n_hd_h 2tnhdh |
GQA | { K ≤ t ; g i , V ≤ t ; g i ∣ i = 1 , 2 , ⋯ , n g } \{K_{\leq t;g_i} ,V_{\leq t;g_i} |i=1,2,\cdots ,n_g\} {K≤t;gi,V≤t;gi∣i=1,2,⋯,ng} | 2 t n g d h 2 t n_gd_h 2tngdh |
MQA | { K ≤ t ; g i , V ≤ t ; g i ∣ g i = 1 } \{K_{\leq t;g_i} ,V_{\leq t;g_i} |g_i=1\} {K≤t;gi,V≤t;gi∣gi=1} | 2 t d h 2 t d_h 2tdh |
MLA | C ≤ t k v , K ≤ t R C^{kv}_{\leq t} ,K^R_{\leq t} C≤tkv,K≤tR | t ( d c + d h R ) t (d_c + d_h^R) t(dc+dhR) |
2 经典Casual-Attention的KV-Cache工作机制
假定当前层attention的输入为 X = [ x 1 ; x 2 ; ⋯ ; x T ] , X ∈ R T × d X=[x_1;x_2;\cdots ;x_T], X\in \mathbb{R} ^ {T\times d} X=[x1;x2;⋯;xT],X∈RT×d, T T T为sequence的长度。通过 W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv 3个线性层得到query ,key,value。
Q = [ q 1 q 2 ⋮ q T ] , K = [ k 1 k 2 ⋮ k T ] , V = [ v 1 v 2 ⋮ v T ] q t , k t , v t ∈ R 1 × d (1) Q=\begin{aligned} &\begin{bmatrix}q_1 \\ q_2 \\ \vdots \\ q_T \end{bmatrix}, K=\begin{bmatrix}k_1 \\ k_2 \\ \vdots \\ k_T \end{bmatrix}, V=\begin{bmatrix}v_1 \\ v_2 \\ \vdots \\ v_T \end{bmatrix} \\ & q_t,k_t,v_t \in \mathbb{R}^{1\times d} \end{aligned} \tag{1} Q= q1q2⋮qT ,K= k1k2⋮kT ,V= v1v2⋮vT qt,kt,vt∈R1×d(1)
随后通过标准的casual attention机制,得到输出 Y Y Y。
在训练阶段,为了通过teaching forcing技巧进行并行化训练,引入一个causal mask M M M,来保证 t t t位置的token只看的到 ≤ t \leq t ≤t的token。这个阶段没有kv-cache。
[
y
1
y
2
⋮
y
T
]
=
s
o
f
t
m
a
x
(
[
q
1
q
2
⋮
q
T
]
[
k
1
T
,
k
2
T
,
⋯
,
k
T
T
]
d
+
M
)
[
v
1
v
2
⋮
v
T
]
(2)
\begin{bmatrix}y_1 \\ y_2 \\ \vdots \\ y_T \end{bmatrix} = \mathrm{softmax} \bigg ( \frac{\begin{bmatrix}q_1 \\ q_2 \\ \vdots \\ q_T \end{bmatrix}\begin{bmatrix}k_1^T , k_2^T , \cdots , k_T^T \end{bmatrix}}{\sqrt{d}} + M\bigg ) \begin{bmatrix}v_1 \\ v_2 \\ \vdots \\ v_T \end{bmatrix} \tag{2}
y1y2⋮yT
=softmax(d
q1q2⋮qT
[k1T,k2T,⋯,kTT]+M)
v1v2⋮vT
(2)
**在生成阶段。**token是按序生成,在模型内部体现在 [ y 1 ; y 2 ; ⋯ ; y T ] \begin{bmatrix}y_1 ; y_2 ; \cdots ; y_T \end{bmatrix} [y1;y2;⋯;yT]的每一行是依次输出的。
对于第一个token的生成只依赖 y 1 y_1 y1,第二个token的生成只依赖 y 2 y_2 y2,依次类推。
对每一个 y t y_t yt,attention的计算如下
y t = s o f t m a x ( q t [ k 1 T , ⋯ , k t T ] d ) [ v 1 ⋮ v t ] = s o f t m a x ( q t K ≤ t T d ) V ≤ t = ∑ i = 1 t s o f t m a x ( q t k i T d ) v i (3) \begin{aligned} y_t &= \mathrm{softmax} \bigg ( \frac{q_t\begin{bmatrix}k_1^T , \cdots , k_t^T \end{bmatrix}}{\sqrt{d}} \bigg ) \begin{bmatrix}v_1 \\ \vdots \\ v_t \end{bmatrix} \\ & = \mathrm{softmax} \bigg ( \frac{q_t \boxed{K^T_{\leq t}}}{\sqrt{d}} \bigg ) \boxed{V_{\leq t}} \\ & = \sum_{i=1}^{t} \mathrm{softmax} \bigg ( \frac{q_t k_i^T}{\sqrt{d}} \bigg ) v_i\\ \end{aligned} \tag{3} yt=softmax(dqt[k1T,⋯,ktT]) v1⋮vt =softmax(dqtK≤tT)V≤t=i=1∑tsoftmax(dqtkiT)vi(3)
如果考虑位置编码,上式写改写为, P i ( ⋅ ) \mathcal{P}_i(\cdot) Pi(⋅)表示位置编码函数
y t = ∑ i = 1 t s o f t m a x ( P t ( q t ) P i ( k i T ) d ) v i (4) y_t = \sum_{i=1}^{t} \mathrm{softmax} \bigg ( \frac{\mathcal{P}_t(q_t)\mathcal{P}_i (k_i^T)}{\sqrt{d}} \bigg ) v_i \tag{4} yt=i=1∑tsoftmax(dPt(qt)Pi(kiT))vi(4)
从上面的计算流程不难看出, y t y_t yt的生成只依赖当前 t t t位置的query,依赖前面所有位置 ≤ t \leq t ≤t的key和value。为了得到 K ≤ t , K_{\leq t}, K≤t, V ≤ t V_{\leq t} V≤t,最naive的做法是:生成 t t t位置的token时,将 X ≤ t X_{\leq t} X≤t作为Attention的输入,以此保证 K ≤ t , V ≤ t K_{\leq t},V_{\leq t} K≤t,V≤t能够被正确计算。Naive的做法没有kv-cache。
但从上面的计算流程我们不难看出, y t y_t yt需要的 K ≤ t , V ≤ t K_{\leq t},V_{\leq t} K≤t,V≤t中 K ≤ t − 1 , V ≤ t − 1 K_{\leq t -1},V_{\leq t-1} K≤t−1,V≤t−1已经在 y t − 1 y_{t-1} yt−1的计算中被计算。因此可以能把 y t − 1 y_{t-1} yt−1算好的 K ≤ t − 1 , V ≤ t − 1 K_{\leq t -1},V_{\leq t-1} K≤t−1,V≤t−1保存起来,在 t t t位置只需计算 k t , v t k_t,v_t kt,vt,再与前面的 K ≤ t − 1 , V ≤ t − 1 K_{\leq t -1},V_{\leq t-1} K≤t−1,V≤t−1进行拼接就可以得到 K ≤ t , V ≤ t K_{\leq t},V_{\leq t} K≤t,V≤t。这样大大减少了冗余的计算量。这就是kv-cache的核心motivation。(公式中被“框起来”的部分是可以cache的。)
用kv-cache的生成思路,
生成第
1
1
1个token时,此时attention层输入
x
1
x_1
x1,输出
y
1
,
(
k
1
,
v
1
)
y_1, (k_1, v_1)
y1,(k1,v1)。
(
k
1
,
v
1
)
(k_1,v_1)
(k1,v1)是缓存的kv-cache
生成第
2
2
2个token时,此时attention层输入
x
2
,
(
k
1
,
v
1
)
x_2,(k_1, v_1)
x2,(k1,v1),输出
y
2
,
(
K
≤
2
,
V
≤
2
)
y_2, (K_{\leq 2}, V_{\leq2})
y2,(K≤2,V≤2)
…
生成第
t
t
t个token时,此时attention层输入
x
t
,
(
K
≤
t
−
1
,
V
≤
t
−
1
)
x_t,(K_{\leq t-1}, V_{\leq t-1})
xt,(K≤t−1,V≤t−1),输出
y
t
,
(
K
≤
t
,
V
≤
t
)
y_t, (K_{\leq t}, V_{\leq t})
yt,(K≤t,V≤t)
…
kv-cache能够显著降低attention的计算量,但随着生成token的增多,kv-cache所需的存储成本呈线性增加,导致GPU的显存成为生成长度的瓶颈。
3 Multi-Head Attention(MHA
) KV-Cache工作机制
paper: https://arxiv.org/pdf/1706.03762
MHA是上面的一个推广。假定MHA的输入为
X
=
[
x
1
;
x
2
;
⋯
;
x
T
]
,
X
∈
R
T
×
d
X=[x_1;x_2;\cdots ;x_T], X\in \mathbb{R} ^ {T\times d}
X=[x1;x2;⋯;xT],X∈RT×d,
T
T
T为sequence的长度。假定有
n
h
n_h
nh个head,每个head投影的维度为
d
h
=
d
n
h
d_h=\frac{d}{n_h}
dh=nhd
通过线性层的矩阵计算,得到不同head下的 Q i , K i , V i , i ∈ { 1 , ⋯ , n h } Q_i,K_i,V_i,i\in \{1,\cdots,n_h\} Qi,Ki,Vi,i∈{1,⋯,nh}
q 1 ; i q_{1;i} q1;i表示head i i i query矩阵在序列位置为 1 1 1处的向量,其他符号记法类似。
在生成阶段每一个head经过attention计算后的 t t t位置的输出 y t ; i y_{t;i} yt;i如下( y 1 ; i , y 2 ; i , ⋯ y T ; i y_{1;i},y_{2;i}, \cdots y_{T;i} y1;i,y2;i,⋯yT;i依序生成),
y t ; i = s o f t m a x ( q t ; i [ k 1 ; i T , ⋯ , k t ; i T ] d ) [ v 1 ; i ⋮ v t ; i ] = s o f t m a x ( q t ; i K ≤ t ; i T d ) V ≤ t ; i (5) \begin{aligned} y_{t;i} &= \mathrm{softmax} \bigg ( \frac{q_{t;i}\begin{bmatrix}k_{1;i}^T , \cdots , k_{t;i}^T \end{bmatrix}}{\sqrt{d}} \bigg ) \begin{bmatrix}v_{1;i} \\ \vdots \\ v_{t;i} \end{bmatrix} \\ & = \mathrm{softmax} \bigg ( \frac{q_{t;i} \boxed{K_{\leq t;i}^T} }{\sqrt{d}} \bigg ) \boxed{V_{\leq t;i}} \\ \end{aligned} \tag{5} yt;i=softmax(dqt;i[k1;iT,⋯,kt;iT]) v1;i⋮vt;i =softmax(dqt;iK≤t;iT)V≤t;i(5)
循环Attention head i = 1 , ⋯ , n h i=1,\cdots,n_h i=1,⋯,nh ,可以计算所有head t t t时刻的输出 { y t ; 1 , ⋯ , y t ; n h } \{y_{t;1}, \cdots ,y_{t;n_h}\} {yt;1,⋯,yt;nh}最后将不同head的输出 y t ; i y_{t;i} yt;i进行拼接,得到最终的输出
y t = C a t ( [ y t ; 1 , y t ; 2 , ⋯ , y t ; n h ] , d i m = 1 ) y t ∈ R 1 × d (6) y_{t} = \mathrm{Cat}([y_{t;1}, y_{t;2}, \cdots ,y_{t;n_h}], \mathrm{dim}=1) \\y_t \in \mathbb{R}^{1 \times d} \tag{6} yt=Cat([yt;1,yt;2,⋯,yt;nh],dim=1)yt∈R1×d(6)
对于MHA
而言,
y
t
y_t
yt的计算缓存的kv-cache为
{
K
≤
t
;
i
,
V
≤
t
;
i
∣
i
=
1
,
2
,
⋯
,
n
h
}
\{K_{\leq t;i} ,V_{\leq t;i} |i=1,2,\cdots ,n_h\}
{K≤t;i,V≤t;i∣i=1,2,⋯,nh}
4 Group Query Attention(GQA
) KV-Cache工作机制
paper: https://arxiv.org/pdf/2305.13245
GQA的attention计算机制与MHA一致。有所区别的是,GQA
为了降低KV-Cache的存储,将attention的head分为了
n
g
n_g
ng组,同一组共享kv-cache。
g i = ⌈ ( i n g ) ⌉ (7) g_i = \lceil (\frac{i}{n_g} ) \rceil\tag{7} gi=⌈(ngi)⌉(7)
⌈ ⋅ ⌉ \lceil \cdot \rceil ⌈⋅⌉是向上取整符号。若 n g = 4 n_g = 4 ng=4,那么 i ∈ { 1 , 2 , 3 , 4 } i \in \{ 1,2,3,4 \} i∈{1,2,3,4}共享 g i = 1 g_i = 1 gi=1这个group的key,value。
同样,在生成阶段 y 1 ; i , y 2 ; i , ⋯ y T ; i y_{1;i},y_{2;i}, \cdots y_{T;i} y1;i,y2;i,⋯yT;i依序生成。每一个head经过attention计算后的 t t t位置的输出 y t ; i y_{t;i} yt;i如下,
y t ; i = s o f t m a x ( q t ; i [ k 1 ; g i T , ⋯ , k t ; g i T ] d ) [ v 1 ; g i ⋮ v t ; g i ] = s o f t m a x ( q t ; i K ≤ t ; g i T d ) V ≤ t ; g i (8) \begin{aligned} y_{t;i} &= \mathrm{softmax} \bigg ( \frac{q_{t;i}\begin{bmatrix}k_{1;g_i}^T , \cdots , k_{t;g_i}^T \end{bmatrix}}{\sqrt{d}} \bigg ) \begin{bmatrix}v_{1;g_i} \\ \vdots \\ v_{t;g_i} \end{bmatrix} \\ & = \mathrm{softmax} \bigg ( \frac{q_{t;i} \boxed{K_{\leq t;g_i}^T} }{\sqrt{d}} \bigg ) \boxed{ V_{\leq t;g_i} } \\ \end{aligned} \tag{8} yt;i=softmax(dqt;i[k1;giT,⋯,kt;giT]) v1;gi⋮vt;gi =softmax(dqt;iK≤t;giT)V≤t;gi(8)
Loop Attention head i = 1 , ⋯ , n h i=1,\cdots,n_h i=1,⋯,nh ,可以计算所有head t t t时刻的输出 { y t ; 1 , ⋯ , y t ; n h } \{y_{t;1}, \cdots ,y_{t;n_h}\} {yt;1,⋯,yt;nh}最后将不同head的输出 y t ; i y_{t;i} yt;i进行拼接,得到最终的输出
y t = C a t ( [ y t ; 1 , y t ; 2 , ⋯ , y t ; n h ] , d i m = 1 ) y t ∈ R 1 × d (9) y_{t} = \mathrm{Cat}([y_{t;1}, y_{t;2}, \cdots ,y_{t;n_h}], \mathrm{dim}=1) \\y_t \in \mathbb{R}^{1 \times d} \tag{9} yt=Cat([yt;1,yt;2,⋯,yt;nh],dim=1)yt∈R1×d(9)
对于GQA
而言,
y
t
y_t
yt的计算缓存的kv-cache为
{
K
≤
t
;
g
i
,
V
≤
t
;
g
i
∣
i
=
1
,
2
,
⋯
,
n
g
}
\{K_{\leq t;g_i} ,V_{\leq t;g_i} |i=1,2,\cdots ,n_g\}
{K≤t;gi,V≤t;gi∣i=1,2,⋯,ng},相比标准的MHA,KV-cache降低了
n
h
n
g
\frac{n_h}{n_g}
ngnh倍
5 Multi Query Attention(MQA
) KV-Cache工作机制
paper: https://arxiv.org/pdf/1911.02150
MQA
是GQA
的一个特例。当
n
g
=
1
n_g=1
ng=1时,即所有head的query共享同一组key, value,此时的GQA
成为MQA
。
对于MQA而言,
y
t
y_t
yt的计算缓存的kv-cache为
{
K
≤
t
;
g
i
,
V
≤
t
;
g
i
∣
g
i
=
1
}
\{K_{\leq t;g_i} ,V_{\leq t;g_i} |g_i=1\}
{K≤t;gi,V≤t;gi∣gi=1},相比标准的MHA
,KV-cache降低了
n
h
n_h
nh倍。
6 Multi-Head Latent Attention (MLA)的工作机制
MLA
是deepseek
提出的一项针对kv-cache的优化。
paper: https://arxiv.org/abs/2405.04434
(一)先抛开位置编码
假定MLA的输入为
X
=
[
x
1
;
x
2
;
⋯
;
x
T
]
,
X
∈
R
T
×
d
X=[x_1;x_2;\cdots ;x_T], X\in \mathbb{R} ^ {T\times d}
X=[x1;x2;⋯;xT],X∈RT×d,
T
T
T为sequence的长度。假定有
n
h
n_h
nh个head,每个head投影的维度为
d
n
h
\frac{d}{n_h}
nhd。初步来看未引入位置编码的MLA像是引入了一个低秩分解矩阵(类似LoRA
的做法)的MHA
head i i i的Q,K,V计算过程如下
[ Q i K i V i ] = [ X W q ; i X W k ; i X W v ; i ] ⏟ M H A → [ Q i K i V i ] = [ X W q / A W q / B ; i X W k v / A W k / B ; i X W k v / A W v / B ; i ] = [ C q W q / B ; i C k v W k / B ; i C k v W v / B ; i ] ⏟ M L A (10) \underbrace{ \begin{bmatrix} Q_i \\K_i\\V_i \end{bmatrix} =\begin{bmatrix} XW_{q;i} \\ XW_{k;i}\\ XW_{v;i} \end{bmatrix} }_{MHA} \rightarrow \underbrace{ \begin{bmatrix} Q_i \\K_i\\V_i \end{bmatrix}= \begin{bmatrix} XW_{q/A} W_{q/B;i}\\ XW_{kv/A}W_{k/B;i}\\ XW_{kv/A}W_{v/B;i} \end{bmatrix}= \begin{bmatrix} C^{q}W_{q/B;i}\\ C^{kv}W_{k/B;i}\\ C^{kv}W_{v/B;i} \end{bmatrix} }_{MLA} \tag{10} MHA QiKiVi = XWq;iXWk;iXWv;i →MLA QiKiVi = XWq/AWq/B;iXWkv/AWk/B;iXWkv/AWv/B;i = CqWq/B;iCkvWk/B;iCkvWv/B;i (10)
矩阵维度变换说明
dimension | |
---|---|
W
k
v
/
A
W_{kv/A}
Wkv/A (下标kv表示key-value的compress latent编码投影矩阵,/A 类比LORA的A矩阵) | R d × d c \mathbb{R} ^{d \times d_c} Rd×dc |
W
k
/
B
;
i
,
W
v
/
B
;
i
W_{k/B;i},W_{v/B;i}
Wk/B;i,Wv/B;i (/B 类比LORA的B矩阵, $;i$ )表示attention第
i
i
i个head | R d c × d n h \mathbb{R} ^{d_c \times \frac{d}{n_h}} Rdc×nhd |
W q / A ; i W_{q/A;i} Wq/A;i | R d × d c ′ \mathbb{R} ^{ d \times d'_c} Rd×dc′ |
W q / B ; i W_{q/B;i} Wq/B;i | R d n h × d c \mathbb{R} ^{\frac{d}{n_h} \times d_c} Rnhd×dc |
C q C^q Cq (上标q表示 query的compress latent code) | R T × d c ′ \mathbb{R} ^ {T\times d_c'} RT×dc′ |
C k v C^{kv} Ckv (表示 key-value的compress latent code) | R T × d c \mathbb{R} ^ {T\times d_c} RT×dc |
在生成阶段,每一个head经过attention计算后的 t t t位置的输出 y t ; i y_{t;i} yt;i如下( y 1 ; i , y 2 ; i , ⋯ y T ; i y_{1;i},y_{2;i}, \cdots y_{T;i} y1;i,y2;i,⋯yT;i依序生成)
y t ; i = s o f t m a x ( q t ; i K ≤ t ; i T d ) V ≤ t ; i = s o f t m a x ( q t ; i W k / B ; i T ( C ≤ t k v ) T d ) C ≤ t k v W v / B ; i \begin{align*} y_{t;i} &= \ \mathrm{softmax} \bigg ( \frac{q_{t;i} K^T_{\leq t;i}}{\sqrt{d}} \bigg ) V_{\leq t;i} \tag{9} \\ & = \mathrm{softmax} \bigg ( \frac{q_{t;i} W_{k/B;i}^{T}(\boxed{C^{kv}_{\leq t}})^T}{\sqrt{d}} \bigg ) \boxed{C^{kv}_{\leq t}}W_{v/B;i}\tag{10} \end{align*} yt;i= softmax(dqt;iK≤t;iT)V≤t;i=softmax(dqt;iWk/B;iT(C≤tkv)T)C≤tkvWv/B;i(9)(10)
式(9)和MHA
的generate阶段的形式相同,当然可以通过缓存
{
K
≤
t
;
i
,
V
≤
t
;
i
∣
i
=
1
,
2
,
⋯
,
n
h
}
\{K_{\leq t;i} ,V_{\leq t;i} |i=1,2,\cdots ,n_h\}
{K≤t;i,V≤t;i∣i=1,2,⋯,nh}实现kv-cache。
但MLA提供了一个新的方法(式10),只需要缓存 C ≤ t k v ∈ R t × d c C^{kv}_{\leq t} \in \mathbb{R}^{t \times d_c} C≤tkv∈Rt×dc即可,相比原始方法的kv-cache的存储单元数量从 t × d n h × n h × 2 = 2 d t t\times \frac{d}{n_h} \times n_h \times 2 = 2dt t×nhd×nh×2=2dt降低到 t × d c t\times d_c t×dc。但这个方法需要引入两个矩阵乘法的计算量。因为 d c d_c dc不大,引入的计算量是可以接受的。(还有一种更为巧妙的方法能规避计算量增加的问题,在(二)中介绍)
(二)引入位置编码的MLA
这个形式也是deepseek论文中提出的形式。有了上面的基础,再理解就很简单了。同样假定MLA的输入为 X = [ x 1 ; x 2 ; ⋯ ; x T ] , X ∈ R T × d X=[x_1;x_2;\cdots ;x_T], X\in \mathbb{R} ^ {T\times d} X=[x1;x2;⋯;xT],X∈RT×d, T T T为sequence的长度。假定有 n h n_h nh个head,每个head投影的维度为 d h = d n h d_h=\frac{d}{n_h} dh=nhd
head i i i的Q,K,V计算过程如下
[ Q i K i V i ] ⇒ [ X W q ; i X W k ; i X W v ; i ] ⏟ M H A ⇒ [ X W q / A W q / B ; i X W k v / A W k / B ; i X W k v / A W v / B ; i ] = [ C q W q / B ; i C k v W k / B ; i C k v W v / B ; i ] ⏟ MLA, no pos embedding ⇒ [ Q i C , Q i R K i C , K R V i ] = [ C q W q C / B ; i , R O P E ( C q W q R / B ; i ) C k v W k C / B ; i , R O P E ( X W k R ) C k v W v / B ; i ] ⏟ MLA with pos embedding (11) \begin{aligned} \begin{bmatrix} Q_i \\K_i\\V_i \end{bmatrix} & \Rightarrow \underbrace{\begin{bmatrix} XW_{q;i} \\ XW_{k;i}\\ XW_{v;i} \end{bmatrix} }_{MHA} \Rightarrow \underbrace{ \begin{bmatrix} XW_{q/A} W_{q/B;i}\\ XW_{kv/A}W_{k/B;i}\\ XW_{kv/A}W_{v/B;i} \end{bmatrix} = \begin{bmatrix} C^{q}W_{q/B;i}\\ C^{kv}W_{k/B;i}\\ C^{kv}W_{v/B;i} \end{bmatrix} }_{\text{MLA, no pos embedding}} \\ & \Rightarrow \underbrace{ \begin{bmatrix} Q_i^C,Q_i^R \\K_i^C, K^R\\V_i \end{bmatrix} = \begin{bmatrix} C^qW_{q^C/B;i},\mathrm{ROPE} (C^qW_{q^R/B;i}) \\ C^{kv}W_{k^C/B;i}, \mathrm{ROPE} (XW_{k^R})\\ C^{kv}W_{v/B;i} \end{bmatrix} }_{\text{MLA with pos embedding}} \end{aligned} \tag{11} QiKiVi ⇒MHA XWq;iXWk;iXWv;i ⇒MLA, no pos embedding XWq/AWq/B;iXWkv/AWk/B;iXWkv/AWv/B;i = CqWq/B;iCkvWk/B;iCkvWv/B;i ⇒MLA with pos embedding QiC,QiRKiC,KRVi = CqWqC/B;i,ROPE(CqWqR/B;i)CkvWkC/B;i,ROPE(XWkR)CkvWv/B;i (11)
矩阵维度变换说明
dimension | |
---|---|
W
k
v
/
A
W_{kv/A}
Wkv/A(下标kv表示key-value的compress latent编码投影矩阵,/A 类比LORA的A矩阵) | R d × d c \mathbb{R} ^{d \times d_c} Rd×dc |
W
k
C
/
B
;
i
,
W
v
/
B
;
i
W_{k^C/B;i},W_{v/B;i}
WkC/B;i,Wv/B;i (/B 类比LORA的B矩阵, $;i$ )表示attention第
i
i
i个head | R d c × d h \mathbb{R} ^{d_c \times d_h} Rdc×dh |
W k R W_{k^R} WkR | R d × d h R \mathbb{R} ^{d \times d_h^R} Rd×dhR |
W q / A ; i W_{q/A;i} Wq/A;i | R d × d c ′ \mathbb{R} ^{d \times d'_c} Rd×dc′ |
W q C / B ; i W_{q^C/B;i} WqC/B;i | R d c ′ × d h \mathbb{R} ^{d'_c \times d_h} Rdc′×dh |
W q R / B ; i W_{q^R/B;i} WqR/B;i | R d c ′ × d h R \mathbb{R} ^{d_c' \times d^R_h } Rdc′×dhR |
C q C^q Cq (上标q表示 query的compress latent code) | R T × d c ′ \mathbb{R} ^ {T\times d_c'} RT×dc′ |
C k v C^{kv} Ckv(表示 key-value的compress latent code) | R T × d c \mathbb{R} ^ {T\times d_c} RT×dc |
Q i C Q_{i}^{C} QiC (上标C表示compression的首字母“C”) | R T × d h \mathbb{R}^{T\times d_h} RT×dh (不含位置编码的query) |
Q
i
R
Q_{i}^{R}
QiR (上标R是RoPE的R ) | R T × d h R \mathbb{R}^{T\times d_h^R} RT×dhR (包含位置编码的query) |
y t ; i = s o f t m a x ( [ q t ; i C , q t ; i R ] [ ( k 1 ; i C ) T , ⋯ , ( k t ; i C ) T ( k 1 R ) T , ⋯ , ( k t R ) T ] d h + d h R ) V ≤ t ; i = s o f t m a x ( [ q t ; i C , q t ; i R ] [ ( K ≤ t ; i C ) T ( K ≤ t R ) T ] d h + d h R ) V ≤ t ; i = s o f t m a x ( [ q t ; i C , q t ; i R ] [ ( C ≤ t k v W k C / B ; i ) T ( K ≤ t R ) T ] d h + d h R ) C ≤ t k v W v / B ; i \begin{align*} y_{t;i} &= \mathrm{softmax} \bigg ( \frac{ \begin{bmatrix} q^C_{t;i}, q^R_{t;i} \end{bmatrix} \begin{bmatrix} (k^C_{1;_i})^T , \cdots , (k^C_{t;i})^T \\ (k^R_{1})^T , \cdots , (k^R_{t})^T \end{bmatrix} }{\sqrt{d_h+d_h^R}} \bigg ) V_{\leq t;i} \tag{12} \\ &= \mathrm{softmax} \bigg ( \frac{ \begin{bmatrix} q^C_{t;i}, q^R_{t;i} \end{bmatrix} \begin{bmatrix} (K^C_{\leq t;i})^T \\ (K^R_{\leq t})^T \end{bmatrix} }{\sqrt{d_h+d_h^R}} \bigg ) V_{\leq t;i} \tag{13} \\ & = \mathrm{softmax} \bigg ( \frac{ \begin{bmatrix} q^C_{t;i}, q^R_{t;i} \end{bmatrix} \begin{bmatrix} ( \boxed{ C^{kv}_{\le t}} W_{k^C/B;i})^T \\ (\boxed{ K^R_{\leq t}})^T \end{bmatrix} }{\sqrt{d_h+d_h^R}} \bigg ) \boxed {C^{kv}_{\leq t} } W_{v/B;i}\tag{14} \end{align*} yt;i=softmax(dh+dhR[qt;iC,qt;iR][(k1;iC)T,⋯,(kt;iC)T(k1R)T,⋯,(ktR)T])V≤t;i=softmax(dh+dhR[qt;iC,qt;iR][(K≤t;iC)T(K≤tR)T])V≤t;i=softmax(dh+dhR[qt;iC,qt;iR] (C≤tkvWkC/B;i)T(K≤tR)T )C≤tkvWv/B;i(12)(13)(14)
从式14可见,加了位置编码的MLA相比无位置编码的情形多缓存了一个
K
≤
t
R
K^R_{\leq t}
K≤tR。这里需要注意,对于所有attention head
K
≤
t
R
K^R_{\leq t}
K≤tR是共享的(类似MQA
)。
此时KV-cache需要缓存 { C ≤ t k v , K ≤ t R } \{C^{kv}_{\leq t}, K^R_{\leq t}\} {C≤tkv,K≤tR} 的存储单元数量为 t ( d c + d h R ) t (d_c + d_h^R) t(dc+dhR)
**与不加位置编码的情形一致,这个方法推理时需要引入两个矩阵乘法的计算量,
C
≤
t
k
v
W
k
C
/
B
;
i
C^{kv}_{\le t} W_{k^C/B;i}
C≤tkvWkC/B;i和
C
≤
t
k
v
W
v
/
B
;
i
C^{kv}_{\leq t} W_{v/B;i}
C≤tkvWv/B;i。**不妨对式(14)再次进行变形,可以看到MLA
又一巧妙的设计
y t ; i = s o f t m a x ( [ q t ; i C , q t ; i R ] [ ( C ≤ t k v W k C / B ; i ) T ( K ≤ t R ) T ] d h + d h R ) C ≤ t k v W v / B ; i = s o f t m a x ( q t ; i C ( W k C / B ; i ) T ( C ≤ t k v ) T + q t ; i R ( K ≤ t R ) T d h + d h R ) C ≤ t k v W v / B ; i \begin{align*} y_{t;i} &= \mathrm{softmax} \bigg ( \frac{ \begin{bmatrix} q^C_{t;i}, q^R_{t;i} \end{bmatrix} \begin{bmatrix} ( \boxed{ C^{kv}_{\le t}} W_{k^C/B;i})^T \\ (\boxed{ K^R_{\leq t}})^T \end{bmatrix} }{\sqrt{d_h+d_h^R}} \bigg ) \boxed {C^{kv}_{\leq t} } W_{v/B;i}\\ &= \mathrm{softmax} \bigg ( \frac{ q^C_{t;i}(W_{k^C/B;i})^T(\boxed{ C^{kv}_{\le t}})^T+ q^R_{t;i} ( \boxed{K^R_{\leq t}})^T }{\sqrt{d_h+d_h^R}} \bigg ) \boxed {C^{kv}_{\leq t} } W_{v/B;i} \tag{15}\\ \end{align*} yt;i=softmax(dh+dhR[qt;iC,qt;iR] (C≤tkvWkC/B;i)T(K≤tR)T )C≤tkvWv/B;i=softmax(dh+dhRqt;iC(WkC/B;i)T(C≤tkv)T+qt;iR(K≤tR)T)C≤tkvWv/B;i(15)
从式(15)可见,在推理时, W q C / B ; i ( W k C / B ; i ) T W_{q^C/B;i}(W_{k^C/B;i})^T WqC/B;i(WkC/B;i)T可以预先合并为1个矩阵,同理 W v / B ; i W_{v/B;i} Wv/B;i 也可以和随后的线性层的权重进行合并。但计算量的降低主要与矩阵乘法的计算顺序改变导致:
计算次序 | element-level乘法执行次数 |
---|---|
q t ; i C ⏟ 2 ◯ ( C ≤ t k v W k C / B ; i ) T ⏟ 1 ◯ \underbrace{q^C_{t;i}}_{\textcircled{2}}\underbrace{( C^{kv}_{\le t} W_{k^C/B;i})^T}_{\textcircled{1}} 2◯ qt;iC1◯ (C≤tkvWkC/B;i)T | t × d c × d h ⏟ 1 ◯ + t × d h × 1 ⏟ 2 ◯ = ( d c + 1 ) t d h \underbrace{t \times d_c \times d_h}_{\textcircled{1}} + \underbrace{t \times d_h \times 1}_{\textcircled{2}} = (d_c + 1)td_h 1◯ t×dc×dh+2◯ t×dh×1=(dc+1)tdh |
q t ; i C ( W k C / B ; i ) T ⏟ 1 ◯ ( C ≤ t k v ) T ⏟ 2 ◯ \underbrace{q^C_{t;i}(W_{k^C/B;i})^T}_{\textcircled{1}}\underbrace{( C^{kv}_{\le t})^T }_{\textcircled{2}} 1◯ qt;iC(WkC/B;i)T2◯ (C≤tkv)T | 1 × d h × d c ⏟ 1 ◯ + 1 × d c × t ⏟ 2 ◯ = ( d h + t ) d c \underbrace{1 \times d_h \times d_c}_{\textcircled{1}} + \underbrace{1 \times d_c \times t}_{\textcircled{2}} = (d_h + t)d_c 1◯ 1×dh×dc+2◯ 1×dc×t=(dh+t)dc |
维度说明: q t ; i C ∈ R 1 × d h , C ≤ t k v ∈ R t × d c , W k C / B ; i ∈ R d c × d h q^C_{t;i} \in \mathbb{R}^{1 \times d_h}, C^{kv}_{\leq t} \in \mathbb{R}^ {t \times d_c}, W_{k^C/B;i} \in \mathbb{R}^{d_c \times d_h} qt;iC∈R1×dh,C≤tkv∈Rt×dc,WkC/B;i∈Rdc×dh
7 小结
文本详细介绍了kv-cache原理,及降低kv-cache存储成本目前常用的MQA,GQA,MLA方法。如有疏漏之处,敬请指出。
不同attention方法 KV cache的存储单元数量
KV-cache | 存储单元数量 | |
---|---|---|
Casual Attention | K ≤ t , V ≤ t K_{\leq t},V_{\leq t} K≤t,V≤t | 2 t d 2td 2td |
MHA | { K ≤ t ; i , V ≤ t ; i ∣ i = 1 , 2 , ⋯ , n h } \{K_{\leq t;i} ,V_{\leq t;i} |i=1,2,\cdots ,n_h\} {K≤t;i,V≤t;i∣i=1,2,⋯,nh} | 2 t n h d h 2 t n_hd_h 2tnhdh |
GQA | { K ≤ t ; g i , V ≤ t ; g i ∣ i = 1 , 2 , ⋯ , n g } \{K_{\leq t;g_i} ,V_{\leq t;g_i} |i=1,2,\cdots ,n_g\} {K≤t;gi,V≤t;gi∣i=1,2,⋯,ng} | 2 t n g d h 2 t n_gd_h 2tngdh |
MQA | { K ≤ t ; g i , V ≤ t ; g i ∣ g i = 1 } \{K_{\leq t;g_i} ,V_{\leq t;g_i} |g_i=1\} {K≤t;gi,V≤t;gi∣gi=1} | 2 t d h 2 t d_h 2tdh |
MLA | C ≤ t k v , K ≤ t R C^{kv}_{\leq t} ,K^R_{\leq t} C≤tkv,K≤tR | t ( d c + d h R ) t (d_c + d_h^R) t(dc+dhR) |