🎯 混合精度训练中的精度使用时机
| 训练阶段 | 使用精度 | 具体内容 | 原因 |
|---|---|---|---|
| 前向传播 | float16 | 输入数据、中间激活、输出logits、损失值 | 节省内存,提高速度 |
| 反向传播 | float16 | 损失梯度、激活梯度、权重梯度 | 节省内存,提高速度 |
| 梯度处理 | float32 | 梯度缩放、梯度裁剪、溢出检测 | 保证精度,稳定计算 |
| 参数更新 | float32 | 权重参数、优化器状态、参数更新 | 保持精度,稳定训练 |
🔄 精度转换的自动机制
1. autocast() 自动转换
with torch.cuda.amp.autocast():
# 前向传播: 自动使用float16
res = model(X)
# 反向传播: 自动使用float16
loss.backward()
2. GradScaler 手动转换
scaler.scale(loss).backward() # 损失缩放
scaler.unscale_(optimizer) # 梯度缩放
scaler.step(optimizer) # 参数更新
📊 精度使用的设计原则
1. 内存密集型操作 → float16
- 激活值、梯度值、中间结果、临时变量
- 目的: 节省内存,提高效率
2. 精度敏感操作 → float32
- 模型参数、优化器状态、参数更新、溢出检测
- 目的: 保证精度,稳定训练
3. 计算密集型操作 → float16
- 矩阵乘法、卷积操作、激活函数、损失计算
- 目的: 利用Tensor Core,提高速度
4. 稳定性敏感操作 → float32
- 梯度裁剪、学习率调度、权重初始化、数值检查
- 目的: 保证稳定性,避免数值问题
🚀 实际应用中的精度分配
前向传播阶段
- 输入数据: float16 (节省内存)
- 模型权重: float32 (保持精度)
- 中间激活: float16 (节省内存)
- 输出logits: float16 (节省内存)
- 损失值: float16 (节省内存)
反向传播阶段
- 损失梯度: float16 (节省内存)
- 激活梯度: float16 (节省内存)
- 权重梯度: float16 (节省内存)
- 梯度累积: float16 (节省内存)
梯度处理阶段
- 梯度缩放: float16 → float32 (精度转换)
- 梯度裁剪: float32 (保持精度)
- 溢出检测: float32 (保持精度)
参数更新阶段
- 权重参数: float32 (保持不变)
- 优化器状态: float32 (保持不变)
- 参数更新: float32 (保持精度)
- 缩放器状态: float32 (保持精度)
🔑 关键理解
- 前向传播和反向传播: 主要使用float16,节省内存和提高速度
- 梯度处理和参数更新: 主要使用float32,保证精度和稳定性
- 自动转换: autocast()自动处理精度转换
- 手动转换: GradScaler手动处理梯度缩放
- 设计原则: 内存密集型用float16,精度敏感型用float32
所以,混合精度训练中,不同阶段使用不同精度:
- 计算阶段(前向、反向)→ float16(效率优先)
- 更新阶段(梯度处理、参数更新)→ float32(精度优先)
这样既保证了训练效率,又保证了训练稳定性!
5486

被折叠的 条评论
为什么被折叠?



