目前各类大模型都支持长文本,例如 kimi chat 以及 gemini pro,都支持 100K 以及更高的上下文长度。但越长的上下文,在推理过程中需要存储的 kv cache 也越多。假设,数据的批次用 b 表示,输入序列的长度仍然用 s 表示,输出序列的长度用 n 表示,隐藏层维度用 h 表示,层数用 l 表示。kv cache 的峰值显存占用大小 = b ∗ ( s + n ) ∗ h ∗ l ∗ 2 ∗ 2 = 4 b l h ( s + n ) b * (s + n) * h * l * 2 * 2 = 4blh(s + n) b∗(s+n)∗h∗l∗2∗2=4blh(s+n),这里的第一个 2 表示 k 和 v cache,第二个 2 表示 float16 数据格式存储 kv cache,每个元素占 2 bytes。
然而,目前的大多数 LLM 会使用 GQA 而非 MHA,因此 kv cache 的占用量会更少,以 transformers 的 modeling_llama.py
脚本中的实现为例:
class LlamaAttention(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
super().__init__()
# ...
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
# ...
def forward(#...) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# ...
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
# ...
其中,q_len
= s + n,bsz
= b,self.hidden_size
= h,然而,self.num_key_value_heads
会小于 self.num_heads
,以 Llama3-8B 为例:
"hidden_size": 4096,
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 8,
k 和 v 的注意力头是 q 的 1/4,因此 kv cache 的峰值显存占用大小还可以继续除以 4,在这里暂时表示为 b l h ( s + n ) blh(s + n) blh(s+n)(注意,不同模型的比例不同,需要根据情况调整计算公式)。
示例:我们继续以 Llama3-8B 为例,来计算不同长度时的 kv cache 显存占用。令 b = 1,n = 32。
- s = 512: 32 × 4096 × ( 512 + 32 ) = 71 , 303 , 168 ≈ 0.066 G B 32 \times 4096 \times (512 + 32) = 71,303,168 \approx 0.066GB 32×4096×(512+32)=71,303,168≈0.066GB。
- s = 16,384: 32 × 4096 × ( 1024 + 32 ) = 71 , 303 , 168 ≈ 2.004 G B 32 \times 4096 \times (1024+ 32) = 71,303,168 \approx 2.004GB 32×4096×(1024+32)=71,303,168≈2.004GB。
- s = 327,680: 32 × 4096 × ( 1024 + 32 ) = 71 , 303 , 168 ≈ 40.004 G B 32 \times 4096 \times (1024+ 32) = 71,303,168 \approx 40.004GB 32×4096×(1024+32)=71,303,168≈40.004GB。
可以看到,随着 context 长度的增加,kv cache 的显存占用量也随之呈线性增长,成为推理的主要瓶颈。在论文《Sequence can Secretly Tell You What to Discard》中,作者介绍了一种优化 KV 缓存的新方法,它能显著减少 KV 缓存的内存占用。通过综合研究,发现在 LLaMA2 系列模型上:
- 相邻 token 的 query 向量之间的相似度非常高;
- 当前 query 的注意力计算可以完全依赖于一小部分前面 query 的注意力信息。
基于这些观察结果,作者提出了一种 KV 缓存驱逐策略 CORM,它能在不对模型进行微调的情况下动态保留重要的 kv 对进行推理。
观察实验与结果
LLMs 中的注意力稀疏性
首先探讨 LLM 注意力层的稀疏性,这是减少 KV 缓存大小的有效前提和依据。具体来说,用重要 key 的比例来表示注意力稀疏性。让 q t ∈ R 1 × d q_t \in \R^{1 \times d} qt∈R1×d 表示第 t 步的 query state 向量, k i ∈ R 1 × d k_i \in \R^{1 \times d} <