深入理解stas00/ml-engineering项目中的推理技术
本文将从技术专家视角,系统性地介绍stas00/ml-engineering项目中涉及的机器学习推理相关概念和技术实现。
推理基础概念
预填充与解码阶段
在模型推理过程中,主要分为两个关键阶段:
-
预填充阶段(Prefill):处理完整的输入提示(prompt),一次性处理所有已知token,并缓存中间状态(KV缓存)。这一阶段延迟很低,即使处理1000个token的提示也能快速完成。
-
解码阶段(Decode):逐个生成新token,基于之前所有token(包括提示和已生成的新token)。这一阶段贡献了生成过程的大部分延迟,因为解码无法并行化。
在线与离线推理
在线推理(Online Inference):
- 实时处理用户查询
- 也称为部署或交互式推理
- 典型应用:聊天机器人、搜索引擎、REST API
- 需要持续运行的推理服务器
离线推理(Offline Inference):
- 批量处理大量提示
- 也称为批量推理
- 典型应用:基准评估、合成数据生成
- 通常不需要独立推理服务器
两种场景的优化重点不同:在线推理追求低首token时间(TTFT)和低延迟;离线推理追求高吞吐量。两者都需要关注预填充和解码阶段的综合token处理吞吐量,这直接决定了推理服务的总成本。
核心推理技术
基础技术
注意力机制变体
- MHA (Multi-Head Attention):多头注意力
- MQA (Multi-Query Attention):多查询注意力
- GQA (Grouped-Query Attention):分组查询注意力
- CLA (Cross-Layer Attention):跨层注意力
批处理技术
静态批处理(Static Batching):
- 简单地将前N个查询一起批处理
- 问题:已完成查询需等待最慢查询完成,增加延迟
连续批处理(Continuous Batching):
- 也称为动态批处理
- 已完成结果立即返回,空位由新查询填充
- 显著改善响应时间
分页注意力(Paged Attention)
- 类似操作系统内存分页机制
- 实现动态内存分配
- 防止内存碎片
- 提高计算设备内存利用率
解码方法
-
贪婪解码(Greedy Decoding):
- 总是选择概率最高的token
- 速度最快但可能错过更优token序列
- 可能导致重复循环
-
束搜索(Beam Search):
- 同时跟踪多个候选序列
- 束宽为3时,每步保留3个最优路径
- 最终选择整体概率最高的序列
- 比贪婪解码慢但质量更高
-
采样方法(Sampling):
- 引入可控随机性
- 主要变体:
- Top-K采样:从概率最高的K个token中随机选择
- Top-p采样(核采样):累积概率达到阈值p的token作为候选
温度参数(Temperature)
温度参数影响采样行为:
- t=0.0:等同于贪婪解码
- 0.0<t<1.0:减少随机性,偏向精确输出
- t=1.0:保持原始概率分布
- t>1.0:增加随机性,促进创造性输出
温度通过调整logits影响softmax概率分布:
scaled_logits = logits / temperature
probs = softmax(scaled_logits)
高级推理技术
引导文本生成(Guided Text Generation)
也称为结构化文本生成或辅助生成,确保模型输出符合特定格式要求。
实现原理:
- 为每个token定义允许的子集
- 从允许子集中选择概率最高的token
- 例如生成JSON时,强制模型按JSON语法规则生成
优势:
- 直接生成正确格式,无需后处理
- 可结合模式信息加速推理(预知固定token序列)
推测解码(Speculative Decoding)
也称为推测推理或辅助生成,使用小模型加速大模型推理:
- 小模型快速生成候选序列
- 大模型并行验证候选序列
- 匹配部分保留,不匹配部分重新生成
最适合:
- 输入接地的任务(翻译、摘要等)
- 贪婪解码模式
- 温度接近0的情况
替代方案:N-gram提示查找解码,无需小模型,直接从提示中查找匹配候选。
隐私保护推理(Privacy-Preserving Inference)
保护用户查询隐私的技术,主要解决方案:
-
全同态加密(FHE):
- 在加密数据上直接计算
- 保护数据隐私同时完成推理
-
安全多方计算(MPC):
- 多方协作计算
- 任何单方无法获取完整信息
性能指标术语
- TTFT (Time to First Token):首token时间
- TPOT (Time Per Output Token):每个输出token时间
- ITL (Inter-Token Latency):token间延迟
- QPS (Queries Per Second):每秒查询数
应用场景与技术选择
输入接地的任务
输出主要源自输入提示的任务,包括:
- 翻译
- 摘要
- 文档问答
- 多轮对话
- 代码编辑
- 语音识别(音频转文本)
这些任务特别适合推测解码等技术。
接地技术(Grounding)
为预训练模型提供训练时未接触的额外信息:
- 提示工程:通过精心设计的提示改变模型行为
- 检索增强生成(RAG):提供与提示相关的额外数据
- 领域微调:使模型适应新知识领域
接地可视为提供上下文,帮助模型生成更相关输出。在多模态应用中,图像或视频可作为文本提示的接地信息。
总结
本文系统梳理了stas00/ml-engineering项目中涉及的机器学习推理核心技术,从基础概念到高级优化技术,为开发者提供了全面的技术参考。理解这些概念和技术对于构建高效、可靠的推理系统至关重要,特别是在处理大规模语言模型时。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考