大模型训练中:混合精度的使用时机

🎯 混合精度训练中的精度使用时机

训练阶段使用精度具体内容原因
前向传播float16输入数据、中间激活、输出logits、损失值节省内存,提高速度
反向传播float16损失梯度、激活梯度、权重梯度节省内存,提高速度
梯度处理float32梯度缩放、梯度裁剪、溢出检测保证精度,稳定计算
参数更新float32权重参数、优化器状态、参数更新保持精度,稳定训练

🔄 精度转换的自动机制

1. autocast() 自动转换
with torch.cuda.amp.autocast():
    # 前向传播: 自动使用float16
    res = model(X)
    # 反向传播: 自动使用float16
    loss.backward()
2. GradScaler 手动转换
scaler.scale(loss).backward()      # 损失缩放
scaler.unscale_(optimizer)         # 梯度缩放
scaler.step(optimizer)             # 参数更新

📊 精度使用的设计原则

1. 内存密集型操作 → float16
  • 激活值、梯度值、中间结果、临时变量
  • 目的: 节省内存,提高效率
2. 精度敏感操作 → float32
  • 模型参数、优化器状态、参数更新、溢出检测
  • 目的: 保证精度,稳定训练
3. 计算密集型操作 → float16
  • 矩阵乘法、卷积操作、激活函数、损失计算
  • 目的: 利用Tensor Core,提高速度
4. 稳定性敏感操作 → float32
  • 梯度裁剪、学习率调度、权重初始化、数值检查
  • 目的: 保证稳定性,避免数值问题

🚀 实际应用中的精度分配

前向传播阶段
  • 输入数据: float16 (节省内存)
  • 模型权重: float32 (保持精度)
  • 中间激活: float16 (节省内存)
  • 输出logits: float16 (节省内存)
  • 损失值: float16 (节省内存)
反向传播阶段
  • 损失梯度: float16 (节省内存)
  • 激活梯度: float16 (节省内存)
  • 权重梯度: float16 (节省内存)
  • 梯度累积: float16 (节省内存)
梯度处理阶段
  • 梯度缩放: float16 → float32 (精度转换)
  • 梯度裁剪: float32 (保持精度)
  • 溢出检测: float32 (保持精度)
参数更新阶段
  • 权重参数: float32 (保持不变)
  • 优化器状态: float32 (保持不变)
  • 参数更新: float32 (保持精度)
  • 缩放器状态: float32 (保持精度)

🔑 关键理解

  1. 前向传播和反向传播: 主要使用float16,节省内存和提高速度
  2. 梯度处理和参数更新: 主要使用float32,保证精度和稳定性
  3. 自动转换: autocast()自动处理精度转换
  4. 手动转换: GradScaler手动处理梯度缩放
  5. 设计原则: 内存密集型用float16,精度敏感型用float32

所以,混合精度训练中,不同阶段使用不同精度

  • 计算阶段(前向、反向)→ float16(效率优先)
  • 更新阶段(梯度处理、参数更新)→ float32(精度优先)

这样既保证了训练效率,又保证了训练稳定性!

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值