The torch.cuda.amp package provides convenient utilities for mixed-precision training, i.e., running different operations in different floating-point precisions (torch.float32, torch.float16, torch.bfloat16) in order to reduce computation time and memory usage. In typical production use, torch.autocast is combined with torch.cuda.amp.GradScaler to implement automatic mixed precision.
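For intuition about the per-op dtype selection, here is a minimal sketch (it assumes a CUDA-capable machine; the tensor shapes are arbitrary): inside an autocast region, matmul-style ops run in torch.float16, while numerically sensitive ops such as sum are kept in torch.float32.

import torch

# Assumes a CUDA device is available; shapes are arbitrary.
a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    c = a @ b        # matmul autocasts to float16
    d = c.sum()      # sum is kept in float32 for numerical stability

print(c.dtype)  # torch.float16
print(d.dtype)  # torch.float32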
The officially recommended pattern inside the training loop is as follows:
import torch

# net, opt (the optimizer), loss_fn, data, targets, device, and epochs are assumed
# to be defined elsewhere; device should be "cuda" for float16 autocast + GradScaler.
# Create a GradScaler once at the beginning of training.
scaler = torch.cuda.amp.GradScaler()

for epoch in range(epochs):
    for input, target in zip(data, targets):
        # Steps 1-2: run the forward pass and the loss computation under autocast,
        # which picks a per-op dtype (float16 here) inside the region.
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)

        # Step 3: scales the loss and calls backward() on the scaled loss to create
        # scaled gradients. Gradient scaling helps prevent gradients with small
        # magnitudes from flushing to zero ("underflowing") in mixed precision.
        scaler.scale(loss).backward()

        # Step 4: unscales the gradients of the optimizer's assigned parameters in-place.
        scaler.unscale_(opt)

        # Step 5: since the gradients of the optimizer's assigned parameters are now
        # unscaled, clip them as usual.
        torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=0.1)

        # Step 6: scaler.step() would normally unscale the gradients first, but skips
        # that here because unscale_ was already called this iteration. If the gradients
        # do not contain infs or NaNs, optimizer.step() is called; otherwise it is skipped.
        scaler.step(opt)

        # Step 7: updates the scale factor for the next iteration.
        scaler.update()

        # Step 8: resets the gradients for the next iteration.
        opt.zero_grad()
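Note that the scale factor maintained by GradScaler is stateful: scaler.update() grows it after a run of successful steps and shrinks it when inf/NaN gradients are detected. When checkpointing, it should therefore be saved and restored along with the model and optimizer. A minimal sketch, reusing net, opt, and scaler from the loop above ("checkpoint.pth" is just a placeholder path):

# Saving: persist the scaler state together with the model and optimizer.
torch.save({
    "model": net.state_dict(),
    "optimizer": opt.state_dict(),
    "scaler": scaler.state_dict(),
}, "checkpoint.pth")

# Resuming: restore all three before continuing training.
checkpoint = torch.load("checkpoint.pth")
net.load_state_dict(checkpoint["model"])
opt.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])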