生成1个token,需要多少KV Cache开销?

部署运行你感兴趣的模型镜像

引言

本文将对比使用MHA(Multi-Head Attention)、MQA(Multi-Query Attention)、GQA(Grouped-Query Attention)和MLA(Multi-Head Latent Attention)这4种注意力机制时,在decoder阶段使用KV cache生成单个token所需的额外缓存空间。

假设

在具有L层、 n h n_h nh个注意头和key维度 d h d_h dh(即单个head的维度)的Transformer模型中,decoder阶段使用KV cache的话,生成一个token需要多少KV cache空间?

模型中有如下关系:d_model=hidden_size=
embed_dim=d_h * n_h

简要介绍各自注意力机制

MHA、MQA、GQA和MLA这几种都是Transformer架构中注意力机制的不同变体,主要区别在于如何处理键值对。

MHA (Multi-Head Attention):

标准的多头注意力机制,每个注意力头都有独立的查询(Q)、键(K)、值(V)矩阵。计算复杂度高但表达能力强,是原始Transformer使用的方法。

MQA (Multi-Query Attention):

多个查询头共享同一组键值对,即只有一个K和V矩阵,但有多个Q矩阵。这大幅减少了KV缓存的内存占用,提高了推理速度,但可能会损失一些表达能力。

GQA (Grouped-Query Attention):

MHA和MQA的折中方案,将查询头分成若干组,每组内的头共享同一组键值对。比如8个查询头可以分成2组,每组4个头共享KV。在保持较好性能的同时减少内存使用。

MLA (Multi-Head Latent Attention):

DeepSeek V2中引入的注意力机制,通过引入潜在空间来进一步优化计算效率。将高维的键值投影到低维潜在空间进行计算,然后再投影回原空间,在保持性能的同时显著降低计算和存储开销。

这些变体的发展趋势是在保持模型性能的前提下,不断优化计算效率和内存使用,特别是在大模型推理场景中越来越重要。

MHA

KV cache的缓存空间计算如下:

对于每个注意力头,需要存储:

  • Key向量: d h d_h dh 维度
  • Value向量: d h d_h dh 维度

因此每个头需要存储 2 × d h 2 × d_h 2×dh 个数值。

对于整个模型:

  • 每层有 n h n_h nh个注意力头
  • 总共有L层
  • 每个位置需要存储:L × n h n_h nh × 2 × d h d_h dh 个数值

如果序列长度为n,那么KV cache的总存储空间为:
n × L × n h × 2 × d h n × L × n_h × 2 × d_h n×L×nh×2×dh

以浮点数精度计算存储大小:

  • 如果使用FP32(4字节): 4 × n × L × n h × 2 × d h 4 × n × L × n_h × 2 × d_h 4×n×L×nh×2×dh 字节
  • 如果使用FP16(2字节): 2 × n × L × n h × 2 × d h 2 × n × L × n_h × 2 × d_h 2×n×L×nh×2×dh 字节

这个缓存空间会随着序列长度n线性增长,这也是为什么长序列推理时内存消耗会快速增加的原因。在实际部署中,KV cache往往是推理时的内存瓶颈之一。

至于,MQA、GQA和MLA以下直接进行换算。汇总总结部分见微信公众号"小窗幽记机器学习":生成1个token,需要多少KV Cache开销?

您可能感兴趣的与本文相关的镜像

Seed-Coder-8B-Base

Seed-Coder-8B-Base

文本生成
Seed-Coder

Seed-Coder是一个功能强大、透明、参数高效的 8B 级开源代码模型系列,包括基础变体、指导变体和推理变体,由字节团队开源

在 `vllm serve` 推理中,**配置上下文长度**和**最小化 KV 缓存使用**需通过参数组合和模型优化实现,以下是具体方法及原理: --- ### **一、配置上下文长度** #### **1. 核心参数:`--max-model-len`** - **作用**:直接限制模型能处理的最大输入+输出序列长度(单位:token)。 - **配置方法**: ```bash vllm serve --model-path ./model \ --max-model-len 4096 # 示例:限制上下文为4096 token ``` - **注意事项**: - 若输入超过此值,会触发截断或报错(取决于模型实现)。 - 需与模型原始架构支持的上下文长度匹配(如 Llama-2 默认 4096)。 #### **2. 动态调整策略** - **按请求动态设置**: 通过 API 请求的 `max_tokens` 参数动态控制生成长度,避免固定长度的显存浪费: ```python import requests data = { "prompt": "Hello, world!", "max_tokens": 100, # 仅生成100 token,减少KV缓存占用 "temperature": 0.7 } response = requests.post("http://localhost:8000/generate", json=data) ``` --- ### **二、最小化 KV 缓存使用** #### **1. 关键参数优化** ##### **(1) `--block-size`** - **作用**:控制 KV 缓存的分块大小(默认 32),较小的块能减少碎片但增加元数据开销- **优化建议**: - **短序列任务**:设为 `16` 或 `32`(平衡显存和性能)。 - **长序列任务**:设为 `64` 或 `128`(减少块数量,降低元数据开销)。 ```bash vllm serve --block-size 16 # 适合短上下文 ``` ##### **(2) `--disable-kv-cache`** - **作用**:完全禁用 KV 缓存(实验性功能),每次生成重新计算注意力,**显著降低显存但大幅增加计算量**。 - **适用场景**: - 极短序列(如 <512 token)。 - 显存极度受限(如单卡 8GB)。 - **配置方法**: ```bash vllm serve --disable-kv-cache # 慎用!性能可能下降50%+ ``` ##### **(3) `--max-num-batched-tokens`** - **作用**:限制批量处理的最大 token 数,间接控制 KV 缓存的峰值占用- **优化建议**: - 设为略高于典型请求长度(如 `512`)。 ```bash vllm serve --max-num-batched-tokens 512 ``` #### **2. 模型与架构优化** ##### **(1) 使用滑动窗口注意力(Sliding Window Attention)** - **原理**:仅缓存最近 `window_size` 个 tokenKV 数据,而非整个上下文。 - **实现方法**: - 修改模型代码,替换标准注意力为滑动窗口版本(需支持自定义模型)。 - 示例(伪代码): ```python from vllm.model_executor.layers.attention import SlidingWindowAttention # 在模型定义中替换原有Attention层 ``` ##### **(2) 量化激活值** - **作用**:将激活值从 FP16/BF16 量化为 int8,减少 KV 缓存的显存占用- **配置方法**: ```bash vllm serve --dtype int8 # 需模型支持int8激活 ``` ##### **(3) 卸载非关键层到 CPU** - **适用场景**:KV 缓存仍不足时,将 Embedding 或 Projection 层卸载到 CPU。 - **实现方法**: - 修改模型代码,使用 `torch.nn.DataParallel` 或自定义设备映射。 - 示例: ```python model.embedding.to("cpu") # 将Embedding层移至CPU ``` #### **3. 系统级优化** ##### **(1) 启用显存碎片整理** - **作用**:减少 KV 缓存分配时的碎片化。 - **配置方法**: ```bash export CUDA_LAUNCH_BLOCKING=1 # 可能影响性能,需测试 ``` ##### **(2) 使用 `CUDA_CACHE_DISABLE=1`** - **作用**:禁用 CUDA 缓存,避免其占用显存。 - **配置方法**: ```bash CUDA_CACHE_DISABLE=1 vllm serve ... ``` --- ### **三、完整配置示例** ```bash vllm serve \ --model-path ./32B-int8-model \ --max-model-len 2048 \ # 限制上下文长度 --block-size 16 \ # 优化KV缓存分块 --disable-kv-cache false \ # 默认启用KV缓存(禁用需谨慎) --max-num-batched-tokens 512 \ # 控制批量处理大小 --tensor-parallel-size 2 # 多卡并行 ``` --- ### **四、验证与监控** #### **1. 监控 KV 缓存使用** - **通过日志**: ```bash vllm serve --log-interval 10 # 每10秒打印显存和KV缓存统计 ``` - **通过 `nvidia-smi`**: ```bash watch -n 1 nvidia-smi # 实时监控显存占用 ``` #### **2. 压力测试** ```python import requests # 测试接近上下文长度的请求 prompt = "A" * 2000 # 2000 token输入 data = {"prompt": prompt, "max_tokens": 100} response = requests.post("http://localhost:8000/generate", json=data) ``` --- ### **五、常见问题与解决方案** | **问题** | **解决方案** | |------------------------------|-----------------------------------------------------------------------------| | KV 缓存仍超出显存 | 降低 `--max-model-len` 或启用 `--disable-kv-cache`(牺牲性能) | | 生成结果质量下降 | 检查是否因 `--block-size` 过大导致注意力分散,适当减小 | | 批量请求延迟高 | 降低 `--max-num-batched-tokens` 或增加 `--tensor-parallel-size` | ---
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值