单设备高效训练
在学习分布式训练之前,我们需要了解一些单个设备高效训练的实用技术,包括混精度训练、梯度检查点等。这些技术通过优化内存利用率、加快训练速度或两者兼而有之来提高模型训练效率。即使在拥有多块设备的机器上,这些方法仍然有效,还可以结合分布式训练的其他方法进一步优化训练过程。
混精度训练
- 在大模型训练背景下,混合精度训练已成为普遍做法,能显著提升训练速度数倍而不显著影响模型性能
- 传统科学计算追求高精度(如FP128/FP64),但深度学习作为高维函数拟合问题,不需要过高精度
- 低精度计算带来显著速度提升:在英伟达A00 SXM和H00 SXM中,FP16运算能力理论峰值是FP32的近30倍
- 训练后期,激活函数梯度可能非常小,FP16有限精度范围可能导致更新无效,此时需要混精度训练
混精度训练包含两个核心部分:
- 半精度计算:使用FP16进行计算密集型的前向传播和反向传播
- 权重备份:维护FP32精度的主权重副本用于参数更新
训练流程:
- 准备两套模型状态:FP32(优化器状态和模型参数)和FP16(模型参数)
- 前向传播和反向传播使用FP16参数
- 参数更新时将梯度与学习率相乘,更新到FP32参数
- 将更新后的FP32参数拷贝为FP16参数进行下一轮训练
优势:
- 计算密集型操作使用半精度显著提升速度
- FP16存储激活值在大批量训练时节省内存
- 分布式环境下FP16减少梯度通信量
损失缩放(Loss Scale)
- 解决FP16下溢问题的关键技术
- 训练后期梯度特别小时,通过缩放损失函数(链式法则作用到梯度)将梯度平移到FP16有效表示范围
- 缩放因子通常在8-32k之间,非所有网络必需
PyTorch实现示例:
from torch.cuda.amp import GradScaler, autocast
scaler = GradScaler()
with autocast():
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
动态损失缩放算法:
- 从高缩放因子开始(如2^24)
- 迭代中检查梯度是否溢出(Infs/Nans)
- 无溢出则保持缩放因子继续迭代
- 检测到溢出则缩放因子减半,直到不溢出
- 训练后期允许提高缩放因子防止下溢
梯度检查点(Gradient Checkpointing)
内存消耗组成
模型状态(必须存储):
- 优化器状态(如Adam中的Momentum和Variance)
- 梯度(G)
- 模型参数(W)
剩余状态(非必须但训练中产生):
- 激活值(反向传播时使用,可重新计算)
- 临时存储(如All-Reduce操作产生的)
- 碎片化存储空间(可通过内存整理优化)
技术原理
- 传统方法存储所有中间激活值消耗大量内存
- 梯度检查点周期性只存储关键层输入输出
- 反向传播需要时,利用存储的关键层输出重新计算激活值
优势:
- 显著减少内存使用,使大模型能在有限内存设备上训练
- 被Turing-NLG 17.2B和GPT-3 175B等大模型采用
代价:
- 增加部分额外计算开销(重新计算激活值)
- 通常与其他内存优化技术结合使用
梯度累积
适用于内存受限或需要模拟大批量训练的场景。通过在多个批次上累积梯度后执行一次更新:
# 前向传播和反向传播
outputs = model(inputs)
loss = criterion(outputs, labels).backward()
# 梯度累积
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
优势:
- 内存优化:允许使用更大的等效批量大小而不增加内存消耗
- 通信优化:分布式训练中减少参数更新频率,降低通信负载
- 训练稳定性:等效于使用更大批量,可能提高训练稳定性