一、 引言
大型语言模型(LLMs),如 OpenAI 的 GPT 系列 [1] 和 Meta 的 LLaMA [2] 等,基于 Transformer 架构 [3] 在自然语言处理领域取得了突破性进展。这些模型通常拥有数十亿乃至数千亿的参数,其推理过程,尤其是在生成长文本时,面临着显著的计算和内存效率挑战。本文旨在深入剖析提升 LLM 推理效率的关键技术,聚焦于 KV 缓存(Key-Value Cache)、预填充(Prefill)和解码(Decode),并着重阐释 KV 缓存如何通过优化 Transformer 模型中的注意力计算,成为加速推理的关键。
二、 大型语言模型的推理过程概览
LLMs 通常采用自回归方式生成文本 [1]。给定一个输入提示(Prompt),模型会迭代地预测序列中的下一个 token,并将已生成的 token 纳入上下文以预测后续内容. 在每一步,模型执行前向传播,计算最可能的下一个 token。对于长序列,这种逐token生成模式导致先前计算的重复,成为效率瓶颈。
图片来源:https://jalammar.github.io/illustrated-gpt2/
三、 什么是 KV 缓存?
3.1 Transformer 模型中的注意力机制与 KV 缓存的诞生
在 Transformer 模型的每一层中,自注意力机制通过计算输入序列中不同位置之间的相关性来捕获上下文信息 [3]。对于输入序列的每个 token,模型会生成三个向量:Query (Q)、Key (K) 和 Value (V)。注意力得分通过 Query 和所有 Key 的点积计算,然后通过 Softmax 函数进行归一化,得到每个 Key 对应的注意力权重。最终,每个 token 的注意力输出是 Value 向量的加权和。数学上,对于一个包含
N
N
N 个 token 的序列,Query、Key 和 Value 的维度通常为
d
k
d_k
dk 和
d
v
d_v
dv。注意力计算可以表示为:
图片来源:https://jalammar.github.io/illustrated-gpt2/
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dkQKT)V
其中 Q , K , V ∈ R N × d k Q, K, V \in \mathbb{R}^{N \times d_k} Q,K,V∈RN×dk, K T K^T KT 是 Key 矩阵的转置, d k \sqrt{d_k} dk 是缩放因子。
3.2 KV 缓存的机制
在自回归生成过程中,对于已经生成的 token,它们的 Key 和 Value 向量在后续生成新的 token 时会被多次用到。KV 缓存是一种内存优化技术,用于存储在预填充阶段和先前解码步骤中计算得到的 Key 和 Value 矩阵。根据研究 [4],对于 Transformer 模型的每一层和每一个注意力头,都会维护一个独立的 KV 缓存,其维度通常为 [batch_size, num_heads, seq_len, head_dim]
。
图片来源:https://jalammar.github.io/illustrated-gpt2/
四、 预填充(Prefill)阶段
4.1 预填充的定义与过程
预填充是推理的首个阶段,当模型接收到长度为
L
L
L 的输入提示时启动。模型对整个输入提示执行一次完整的前向传播。在这个过程中,对于 Transformer 模型的每一层和每一个注意力头,模型会为每个输入 token 计算出对应的 Query、Key 和 Value 向量。关键在于,预填充阶段会将计算得到的 所有 输入 token 的 Key 和 Value 向量存储到 KV 缓存中。具体来说,输入 token 序列首先通过 Embedding 层转换为 Embedding 向量。然后,这些向量依次通过 Transformer 的各个层。在每一层的自注意力模块中,对于输入序列的每个位置
i
∈
{
1
,
.
.
.
,
L
}
i \in \{1, ..., L\}
i∈{1,...,L},模型计算
Q
i
,
K
i
,
V
i
Q_i, K_i, V_i
Qi,Ki,Vi。Key 向量
{
K
1
,
.
.
.
,
K
L
}
\{K_1, ..., K_L\}
{K1,...,KL} 组成 Key 矩阵
K
p
r
e
f
i
l
l
∈
R
L
×
d
k
K_{prefill} \in \mathbb{R}^{L \times d_k}
Kprefill∈RL×dk,Value 向量
{
V
1
,
.
.
.
,
V
L
}
\{V_1, ..., V_L\}
{V1,...,VL} 组成 Value 矩阵
V
p
r
e
f
i
l
l
∈
R
L
×
d
v
V_{prefill} \in \mathbb{R}^{L \times d_v}
Vprefill∈RL×dv,并存储在 KV 缓存中。此过程对整个输入序列并行进行,为后续的解码阶段准备了初始的上下文信息。预填充阶段的输出是填充好的 KV 缓存以及最后一个输入 token 的隐藏状态。
4.2 预填充阶段的变化
在预填充阶段,KV 缓存从空状态开始,逐步填充输入提示中每个 token 的 Key 和 Value 向量。在完成预填充后,KV 缓存包含了输入序列的完整上下文表示,供后续的解码阶段使用。
五、 解码(Decode)阶段
5.1 解码的定义与过程
解码是模型在预填充之后,逐个生成输出 token 的阶段。假设模型已经生成了 t − 1 t-1 t−1 个 token,KV 缓存中已经存储了这 t − 1 t-1 t−1 个 token(以及预填充的输入 token)的 Key 和 Value 向量。解码阶段的目标是生成第 t t t 个 token。在每一步 t t t,模型以前一步生成的 token 的 Embedding 作为输入,通过 Transformer 的各个层,计算得到当前 token 的 Query 向量 Q t ∈ R 1 × d k Q_t \in \mathbb{R}^{1 \times d_k} Qt∈R1×dk。关键在于,模型不再需要重新计算先前 L + t − 1 L+t-1 L+t−1 个 token 的 Key 和 Value 向量。它直接使用 KV 缓存中已存储的 Key 矩阵 K c a c h e d ∈ R ( L + t − 1 ) × d k K_{cached} \in \mathbb{R}^{(L+t-1) \times d_k} Kcached∈R(L+t−1)×dk 和 Value 矩阵 V c a c h e d ∈ R ( L + t − 1 ) × d v V_{cached} \in \mathbb{R}^{(L+t-1) \times d_v} Vcached∈R(L+t−1)×dv。模型将当前 token 的 Query 向量 Q t Q_t Qt 与 K c a c h e d K_{cached} Kcached 进行注意力计算:
Attention ( Q t , K c a c h e d , V c a c h e d ) = softmax ( Q t K c a c h e d T d k ) V c a c h e d \text{Attention}(Q_t, K_{cached}, V_{cached}) = \text{softmax}(\frac{Q_t K_{cached}^T}{\sqrt{d_k}})V_{cached} Attention(Qt,Kcached,Vcached)=softmax(dkQtKcachedT)Vcached
得到当前 token 的注意力输出,并预测下一个 token。一旦第 t t t 个 token 生成,其对应的 Key 向量 K t K_t Kt 和 Value 向量 V t V_t Vt 会被计算出来,并分别追加到 KV 缓存中的 K c a c h e d K_{cached} Kcached 和 V c a c h e d V_{cached} Vcached 矩阵中,为下一步生成第 t + 1 t+1 t+1 个 token 做准备。
5.2 解码阶段的变化
在解码阶段,KV 缓存随着每个新生成 token 的出现而动态增长。每生成一个 token,其对应的 Key 和 Value 向量就会被添加到 KV 缓存的末尾。这样,在后续的解码步骤中,模型可以利用所有先前生成 token 的上下文信息,而无需重新计算。
六、 KV 缓存的优化策略
为了进一步提高推理效率和处理更长上下文,研究人员提出了多种 KV 缓存的优化策略:
- 分块 KV 缓存(PagedAttention) [5]:传统的 KV 缓存为每个推理请求分配连续的内存空间,可能导致内存碎片和浪费,尤其是在处理不同长度的序列时。PagedAttention 将 KV 缓存分割成固定大小的内存块(pages),并动态地为每个请求分配和管理这些块,类似于操作系统的内存分页机制。逻辑上的 KV 缓存被映射到不连续的物理内存块上,从而更有效地利用内存,减少碎片,并支持更长的上下文。
- KV 缓存压缩 [6, 7]:随着生成序列的增长,KV 缓存会消耗大量内存,成为处理超长序列的瓶颈。KV 缓存压缩技术旨在减小 KV 缓存的内存占用。常见的技术包括:
- 量化 (Quantization):将 Key 和 Value 向量从高精度浮点数(如 FP16 或 FP32)转换为低精度整数(如 INT8 或 INT4),从而减少每个元素的内存占用。
- 剪枝 (Pruning):移除 KV 缓存中不重要的元素,例如基于注意力权重或其他重要性指标进行筛选。
- 知识蒸馏 (Knowledge Distillation):训练一个更小的模型来近似存储在 KV 缓存中的信息。
- 选择性 KV 缓存 [8]:并非所有历史 token 的信息对于生成当前 token 都同等重要。选择性 KV 缓存策略旨在识别并保留对当前生成最重要的 Key 和 Value 信息,而丢弃或降级不重要的信息,以减少内存和计算开销。这可以通过分析注意力权重、学习 token 的重要性评分等方法实现。例如,一些研究探索了基于滑动窗口或者衰减机制来管理 KV 缓存。
七、 结论
KV 缓存是现代大型语言模型推理优化的关键技术。通过在预填充阶段存储输入提示的 Key 和 Value 向量,并在解码阶段重复使用和增量更新这些信息,模型避免了大量的冗余计算,从而实现了更快的推理速度和更低的资源消耗,尤其是在生成长序列时。深入理解 KV 缓存、预填充和解码的工作原理,以及相关的优化策略,对于进一步研究和应用大型语言模型至关重要。
内容同步在我的微信公众号: 智语Bot
八、 参考文献
-
[1] Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., … & Amodei, D. (2020). Language models are few-shot learners.
-
[2] Touvron, H., Lavril, T., Izacard, G., Martinet, X., Lachaux, M. A., Lacroix, T., … & Lample, G. (2023). Llama 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288.
-
[3] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., … & Polosukhin, I. (2017). Attention is all you need.
-
[4] Gehrmann, S., Dai, Y., Pilehvar, M. T., & Gurevych, I. (2021). Efficiently Modeling Long Sequences with Structured State Spaces. arXiv preprint arXiv:2111.00043. (虽然不是直接关于 KV 缓存,但讨论了长序列建模的挑战,KV 缓存是解决这些挑战的关键)
-
[6] Frantar, E., Alistarh, D., & Hoefler, T. (2023). SpQR: A Sparse-Quantized Representation for Near-Lossless Compression of Large Transformer Models. *arXiv preprint arXiv:2303.15698
-
[7] Bondarenko, I., কেটে, T., & 팀, D. (2021). Understanding the Trade-offs in Training Quantized Neural Networks. arXiv preprint arXiv:2103.02288. (讨论了量化对神经网络的影响)
-
[8] Geiping, J., Bauchwitz, M., Jin, H., & Goldstein, T. (2024). Retentive Key-Value Memory for Long-Form Sequence Modeling.