文章目录
在大型语言模型(LLM)的推理过程中,除了 Prefill(预填充)阶段,还包含多个关键阶段,每个阶段对性能和资源消耗有不同的影响。以下是完整的阶段划分及其作用:
1. 完整推理阶段划分
阶段 | 触发时机 | 核心任务 | 计算特点 | 硬件影响 |
---|---|---|---|---|
Prefill | 接收用户输入后首轮计算 | 计算输入prompt的KV缓存 | 高并行,长序列O(N²) | GPU计算单元饱和 |
Decoding | 生成每个输出token时 | 自回归生成+更新KV缓存 | 串行,内存带宽受限O(N) | 显存带宽瓶颈 |
Post-Processing | 生成结束或达到停止条件 | 结果格式化/采样/安全过滤 | CPU密集型 | CPU负载 |
Prompt Chunking | 超长输入时(可选) | 分块处理prompt避免OOM | 分块并行 | 内存-计算平衡 |
Cache Management | KV缓存超过上下文窗口 | 滑动窗口/丢弃旧token | 缓存压缩 | 显存优化 |
2. 各阶段详解
(1). Prefill(预填充)阶段
- 触发时机:收到用户输入(prompt)后的首次计算。
- 核心任务:
- 对输入的所有token执行完整的前向计算,生成每个token的Key-Value(KV)矩阵。
- 将KV缓存到显存中,供后续生成阶段复用。
- 计算特点:
- 高并行性:一次性处理整个prompt序列(适合GPU并行计算)。
- 复杂度:与prompt长度平方相关(O(n²)),长prompt可能成为延迟瓶颈。
- 优化技术:
- FlashAttention:减少内存访问开销。
- 分块处理:将长prompt拆分为多个块(如vLLM的PagedAttention)。
示例:
# 伪代码:Prefill过程
prompt = "The weather is"
tokens = tokenizer.encode(prompt) # 假设输出 [1, 45, 12]
k_cache, v_cache = model.prefill(tokens) # 计算并缓存所有token的KV
(2). Decoding(自回归解码)阶段
-
触发时机:生成每个输出token时。
-
核心任务:
- 基于KV缓存生成单个token(计算当前token的Query,与缓存的Key做注意力计算)。
- 采样下一个token(如贪心搜索、核采样)。
- 将新token的KV追加到缓存。
-
瓶颈:
- 内存带宽限制:每个token需加载整个KV缓存(如7B模型约10GB/Token)。
- 示例:生成100token时,显存带宽成为主要限制。
-
加速方法:
- 推测解码:用小模型草案加速(如Medusa)。
- 连续批处理:合并多个请求(如TGI框架)。
-
关键操作:
- 计算当前token的
Q
矩阵,与缓存的K
做注意力计算。 - 采样下一个token(如贪心搜索、核采样)。
- 将新token的
K/V
追加到缓存。
示例:
- 计算当前token的
next_token = model.generate_step(k_cache, v_cache) # 生成一个token
k_cache, v_cache = update_cache(k_cache, v_cache, next_token) # 动态扩展缓存
(3). Post-Processing(后处理)阶段
- 触发时机:生成结束或达到停止条件(如max_length)。
- 常见操作:
- 格式化:添加标点、分段、去除重复文本。
- 采样调整:Temperature/Top-p控制多样性。
- 安全过滤:移除敏感有害内容(如NVIDIA NeMo Guardrails)。
示例:
# 后处理:过滤敏感词
from transformers import TextGenerationPipeline
generator = pipeline("text-generation", model="gpt-3")
output = generator("Explain nuclear energy:", safety_filter=True)
(4) Prompt Chunking(输入分块)
- 应用场景:处理超过上下文窗口的长文本(如GPT-4的32K tokens)。
- 实现方式:
- 物理分块:将输入拆分为多个段,分别Prefill。
- 逻辑分块:通过稀疏注意力(如Longformer的滑动窗口)减少计算量。
(5). Cache Management(缓存管理,可选)
- 场景:处理长文本时超出上下文窗口。
- 策略:
- 滑动窗口:保留最近N个token(如Llama 2的4K窗口)。
- 动态丢弃:根据注意力权重淘汰低重要性token。
示例:
if len(k_cache) > max_ctx_len:
k_cache = k_cache[-max_ctx_len:] # 滑动窗口截断
3. 阶段间的资源消耗对比
阶段 | 计算强度 | 内存压力 | 典型耗时比例(长文本) |
---|---|---|---|
Prefill | ★★★★★ | 高(显存峰值) | 60% |
Decoding | ★★☆☆☆ | 持续显存占用 | 35% |
Post-Process | ★☆☆☆☆ | CPU内存 | 5% |
4. 实际框架中的阶段控制
以 vLLM 推理框架为例:
from vllm import SamplingParams
# 1. Prefill + Decoding 配置
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate("Explain AI in 3 steps:", sampling_params)
# 2. 显式分块处理(长文本)
chunked_outputs = []
for chunk in split_long_text(input_text, chunk_size=512):
outputs = llm.generate(chunk, sampling_params)
chunked_outputs.extend(outputs)
# 3. 后处理
final_text = postprocess(chunked_outputs)
5. 优化方向
- Prefill加速:使用FlashAttention-2或Triton优化矩阵乘。
- Decoding优化:
- 连续批处理:合并多个请求的生成步骤(如TGI框架)。
- 推测解码:用小模型草案加速大模型推理(如Medusa)。
- 缓存压缩:
- 量化:将KV缓存转为FP8/INT4(如AWQ)。
- 共享缓存:多用户请求复用相似prompt的缓存。
总结
LLM推理是 多阶段协同 的过程:
- Prefill 决定首次响应速度。
- Decoding 影响生成吞吐量。
- 后处理 保障输出质量。
理解这些阶段有助于针对性优化(如降低长prompt的Prefill开销,或提高Decoding的并行度)。
- 优化延迟:长prompt场景下,Prefill耗时占比高,需针对性加速(如FlashAttention)。
- 提高吞吐量:Decoding阶段通过批处理提升GPU利用率。
- 资源管理:合理控制KV缓存内存,避免OOM错误。