PyTorch TorchTune 教程:使用 QAT 微调 Llama3 模型
概述
本文将详细介绍如何在 PyTorch TorchTune 框架中使用量化感知训练(Quantization-Aware Training, QAT)技术微调 Llama3 模型。QAT 是一种先进的模型量化技术,能够在训练过程中模拟量化效果,从而显著减少量化带来的精度损失。
量化感知训练(QAT)基础
什么是 QAT?
量化感知训练是一种在训练或微调过程中模拟量化数值的技术,目的是最终产生比简单的训练后量化(PTQ)更高质量的量化模型。在 QAT 过程中:
- 权重和/或激活值被"伪量化"(fake quantized)
- 数值被转换为量化后的形式,但仍保持原始数据类型(如 bfloat16)
- 模型能够适应量化噪声并相应调整权重
QAT 与 PTQ 的区别
# PTQ: 实际量化为 int8
x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)
# QAT: 伪量化,仍保持浮点类型
x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale
在 Llama3 上应用 QAT
准备工作
- 安装 TorchTune 框架
- 下载 Llama3-8B 模型权重
- 熟悉 TorchTune 的基本使用方法
QAT 实现步骤
1. 准备模型
import copy
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
FakeQuantizeConfig,
IntXQuantizationAwareTrainingConfig,
)
from torchtune.models.llama3 import llama3_8b
model = llama3_8b()
original_model = copy.deepcopy(model)
# 配置 int8 动态非对称每token激活 + int4 对称每组权重
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
# 准备QAT微调模型
quantize_(model, qat_config)
prepared_model = model
2. 模型结构变化
准备后的模型会将所有线性层替换为 FakeQuantizedLinear
,模拟 int8 动态非对称每token激活 + int4 对称每组权重的数值特性。
3. 微调模型
# 正常进行微调训练
train_loop(prepared_model)
4. 转换为量化模型
微调完成后,将伪量化模型转换为真正的量化模型:
from torchao.quantization.qat import FromIntXQuantizationAwareTrainingConfig
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
# 转换模型
quantize_(prepared_model, FromIntXQuantizationAwareTrainingConfig())
quantize_(prepared_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
converted_model = prepared_model
TorchTune 中的 QAT 微调配方
配置 QAT 微调
-
复制默认 QAT 配置:
tune cp llama3/8B_qat_full custom_8B_qat_full.yaml
-
修改配置参数:
dataset: _component_: torchtune.datasets.text_completion_dataset source: allenai/c4 column: text name: en split: train epochs: 1 max_steps_per_epoch: 2000 fake_quant_after_n_steps: 1000 # 前1000步禁用伪量化
运行 QAT 微调
tune run --nnodes 1 --nproc_per_node 6 qat_distributed --config custom_8B_qat_full.yaml
注意:
- 需要至少6个GPU,每个GPU至少有80GB VRAM
- QAT 会引入约30%的训练速度下降
- 使用激活检查点时,每个GPU的内存占用增加小于5GB
量化 QAT 模型
QAT 微调后得到的是未量化的 bfloat16 模型,需要额外进行量化步骤:
-
复制量化配置:
tune cp quantization custom_quantization.yaml
-
修改配置:
model: _component_: torchtune.models.llama3.llama3_8b checkpointer: _component_: torchtune.training.FullModelMetaCheckpointer checkpoint_dir: <your QAT checkpoint dir> checkpoint_files: [ft-model-00001-of-00001.bin] output_dir: <your QAT checkpoint dir> model_type: LLAMA3 quantizer: _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer groupsize: 256
-
运行量化:
tune run quantize --config custom_quantization.yaml
评估量化模型
使用 EleutherAI 评估工具对量化模型进行评估:
-
复制评估配置:
tune cp eleuther_evaluation custom_eleuther_evaluation.yaml
-
修改配置:
model: _component_: torchtune.models.llama3.llama3_8b checkpointer: _component_: torchtune.training.FullModelTorchTuneCheckpointer checkpoint_dir: <your quantized model checkpoint dir> checkpoint_files: [ft-model-00001-of-00001-8da4w.bin] output_dir: <your quantized model checkpoint dir> model_type: LLAMA3 tasks: ["hellaswag", "wikitext"] quantizer: _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer groupsize: 256
-
运行评估:
tune run eleuther_eval --config my_eleuther_evaluation.yaml
结果对比
QAT 量化模型结果
| Tasks |Version|Filter|n-shot| Metric |Value | |Stderr|
|---------|------:|------|-----:|---------------|-----:|---|------|
|wikitext | 2|none | 0|word_perplexity|9.9148|± |N/A |
|hellaswag| 1|none | 0|acc_norm |0.7536|± |0.0043|
PTQ 量化模型结果
| Tasks |Version|Filter|n-shot| Metric | Value | |Stderr|
|---------|------:|------|-----:|---------------|------:|---|------|
|wikitext | 2|none | 0|word_perplexity|10.7735|± |N/A |
|hellaswag| 1|none | 0|acc_norm | 0.7390|± |0.0044|
原始浮点模型结果
| Tasks |Version|Filter|n-shot| Metric |Value | |Stderr|
|---------|------:|------|-----:|---------------|-----:|---|------|
|wikitext | 2|none | 0|word_perplexity|8.7248|± |N/A |
|hellaswag| 1|none | 0|acc_norm |0.7610|± |0.0043|
结论:QAT 相比 PTQ 能显著减少量化带来的精度损失,在 hellaswag 任务中,QAT 仅损失 0.74% 准确率,而 PTQ 损失 2.20%。
最佳实践
- 延迟伪量化:前1000步禁用伪量化,让权重先稳定
- 配置一致性:确保微调和量化使用相同的量化器配置
- 资源规划:QAT 需要更多计算资源和时间,提前做好规划
- 评估对比:始终与原始浮点模型和PTQ模型进行对比评估
通过本教程,您应该已经掌握了在 TorchTune 中使用 QAT 技术微调 Llama3 模型的完整流程,并能够评估量化模型的性能表现。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考