PyTorch TorchTune 教程:使用 QAT 微调 Llama3 模型

PyTorch TorchTune 教程:使用 QAT 微调 Llama3 模型

torchtune A Native-PyTorch Library for LLM Fine-tuning torchtune 项目地址: https://gitcode.com/gh_mirrors/to/torchtune

概述

本文将详细介绍如何在 PyTorch TorchTune 框架中使用量化感知训练(Quantization-Aware Training, QAT)技术微调 Llama3 模型。QAT 是一种先进的模型量化技术,能够在训练过程中模拟量化效果,从而显著减少量化带来的精度损失。

量化感知训练(QAT)基础

什么是 QAT?

量化感知训练是一种在训练或微调过程中模拟量化数值的技术,目的是最终产生比简单的训练后量化(PTQ)更高质量的量化模型。在 QAT 过程中:

  1. 权重和/或激活值被"伪量化"(fake quantized)
  2. 数值被转换为量化后的形式,但仍保持原始数据类型(如 bfloat16)
  3. 模型能够适应量化噪声并相应调整权重

QAT 与 PTQ 的区别

# PTQ: 实际量化为 int8
x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)

# QAT: 伪量化,仍保持浮点类型
x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale

在 Llama3 上应用 QAT

准备工作

  1. 安装 TorchTune 框架
  2. 下载 Llama3-8B 模型权重
  3. 熟悉 TorchTune 的基本使用方法

QAT 实现步骤

1. 准备模型
import copy
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchtune.models.llama3 import llama3_8b

model = llama3_8b()
original_model = copy.deepcopy(model)

# 配置 int8 动态非对称每token激活 + int4 对称每组权重
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)

# 准备QAT微调模型
quantize_(model, qat_config)
prepared_model = model
2. 模型结构变化

准备后的模型会将所有线性层替换为 FakeQuantizedLinear,模拟 int8 动态非对称每token激活 + int4 对称每组权重的数值特性。

3. 微调模型
# 正常进行微调训练
train_loop(prepared_model)
4. 转换为量化模型

微调完成后,将伪量化模型转换为真正的量化模型:

from torchao.quantization.qat import FromIntXQuantizationAwareTrainingConfig
from torchao.quantization import Int8DynamicActivationInt4WeightConfig

# 转换模型
quantize_(prepared_model, FromIntXQuantizationAwareTrainingConfig())
quantize_(prepared_model, Int8DynamicActivationInt4WeightConfig(group_size=32))

converted_model = prepared_model

TorchTune 中的 QAT 微调配方

配置 QAT 微调

  1. 复制默认 QAT 配置:

    tune cp llama3/8B_qat_full custom_8B_qat_full.yaml
    
  2. 修改配置参数:

    dataset:
      _component_: torchtune.datasets.text_completion_dataset
      source: allenai/c4
      column: text
      name: en
      split: train
    
    epochs: 1
    max_steps_per_epoch: 2000
    fake_quant_after_n_steps: 1000  # 前1000步禁用伪量化
    

运行 QAT 微调

tune run --nnodes 1 --nproc_per_node 6 qat_distributed --config custom_8B_qat_full.yaml

注意

  • 需要至少6个GPU,每个GPU至少有80GB VRAM
  • QAT 会引入约30%的训练速度下降
  • 使用激活检查点时,每个GPU的内存占用增加小于5GB

量化 QAT 模型

QAT 微调后得到的是未量化的 bfloat16 模型,需要额外进行量化步骤:

  1. 复制量化配置:

    tune cp quantization custom_quantization.yaml
    
  2. 修改配置:

    model:
      _component_: torchtune.models.llama3.llama3_8b
    
    checkpointer:
      _component_: torchtune.training.FullModelMetaCheckpointer
      checkpoint_dir: <your QAT checkpoint dir>
      checkpoint_files: [ft-model-00001-of-00001.bin]
      output_dir: <your QAT checkpoint dir>
      model_type: LLAMA3
    
    quantizer:
      _component_: torchtune.training.quantization.Int8DynActInt4WeightQATQuantizer
      groupsize: 256
    
  3. 运行量化:

    tune run quantize --config custom_quantization.yaml
    

评估量化模型

使用 EleutherAI 评估工具对量化模型进行评估:

  1. 复制评估配置:

    tune cp eleuther_evaluation custom_eleuther_evaluation.yaml
    
  2. 修改配置:

    model:
      _component_: torchtune.models.llama3.llama3_8b
    
    checkpointer:
      _component_: torchtune.training.FullModelTorchTuneCheckpointer
      checkpoint_dir: <your quantized model checkpoint dir>
      checkpoint_files: [ft-model-00001-of-00001-8da4w.bin]
      output_dir: <your quantized model checkpoint dir>
      model_type: LLAMA3
    
    tasks: ["hellaswag", "wikitext"]
    
    quantizer:
      _component_: torchtune.training.quantization.Int8DynActInt4WeightQuantizer
      groupsize: 256
    
  3. 运行评估:

    tune run eleuther_eval --config my_eleuther_evaluation.yaml
    

结果对比

QAT 量化模型结果

|  Tasks  |Version|Filter|n-shot|    Metric     |Value |   |Stderr|
|---------|------:|------|-----:|---------------|-----:|---|------|
|wikitext |      2|none  |     0|word_perplexity|9.9148|±  |N/A   |
|hellaswag|      1|none  |     0|acc_norm       |0.7536|±  |0.0043|

PTQ 量化模型结果

|  Tasks  |Version|Filter|n-shot|    Metric     | Value |   |Stderr|
|---------|------:|------|-----:|---------------|------:|---|------|
|wikitext |      2|none  |     0|word_perplexity|10.7735|±  |N/A   |
|hellaswag|      1|none  |     0|acc_norm       | 0.7390|±  |0.0044|

原始浮点模型结果

|  Tasks  |Version|Filter|n-shot|    Metric     |Value |   |Stderr|
|---------|------:|------|-----:|---------------|-----:|---|------|
|wikitext |      2|none  |     0|word_perplexity|8.7248|±  |N/A   |
|hellaswag|      1|none  |     0|acc_norm       |0.7610|±  |0.0043|

结论:QAT 相比 PTQ 能显著减少量化带来的精度损失,在 hellaswag 任务中,QAT 仅损失 0.74% 准确率,而 PTQ 损失 2.20%。

最佳实践

  1. 延迟伪量化:前1000步禁用伪量化,让权重先稳定
  2. 配置一致性:确保微调和量化使用相同的量化器配置
  3. 资源规划:QAT 需要更多计算资源和时间,提前做好规划
  4. 评估对比:始终与原始浮点模型和PTQ模型进行对比评估

通过本教程,您应该已经掌握了在 TorchTune 中使用 QAT 技术微调 Llama3 模型的完整流程,并能够评估量化模型的性能表现。

torchtune A Native-PyTorch Library for LLM Fine-tuning torchtune 项目地址: https://gitcode.com/gh_mirrors/to/torchtune

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

滑隽蔚Maia

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值