超参数
一、warmup_steps学习率预热
学习率预热是在整个训练流程中对学习率进行动态调整的一个阶段。在开始的 warmup_steps (这里是 500 步)内,学习率从一个较小的值逐渐增加到预设的初始学习率。这个过程是与模型的参数更新同步进行的,也就是在训练的前 500 步中,模型一边更新参数,一边逐渐提升学习率。
- 缓解梯度爆炸或梯度消失问题
在训练初期,模型的参数是随机初始化的。如果直接使用较大的学习率,可能会导致梯度在反向传播过程中变得非常大(梯度爆炸)或者非常小(梯度消失)。梯度爆炸会使模型参数更新幅度过大,导致模型无法收敛;梯度消失则会使模型参数更新几乎停滞,训练难以进行。
学习率预热通过在开始的 warmup_steps 步数内,将学习率从一个较小的值逐渐增加到初始学习率,让模型在训练初期以较小的步长更新参数,使模型有足够的时间来适应数据分布,从而缓解梯度爆炸或梯度消失的问题,使训练过程更加稳定。 - 帮助模型更好地探索参数空间
在训练开始时,模型对数据的分布还没有充分的了解。如果学习率过大,模型可能会过快地陷入局部最优解,而无法充分探索整个参数空间。通过学习率预热,在训练初期使用较小的学习率,模型可以更细致地探索参数空间,找到更优的参数组合,从而提高模型的泛化能力。 - 适应不同的数据分布
当使用预训练模型进行微调或者处理复杂的数据分布时,模型可能需要一些时间来适应新的数据。学习率预热可以让模型在开始阶段以较小的学习率进行调整,逐步适应新的数据分布,避免因学习率过大而破坏预训练模型的权重,从而更好地利用预训练的优势。
那么这个较小的值是怎样得到的?在代码实现里,这个较小值通常会在学习率调度器的实现中被设定。以 PyTorch 为例,WarmupLinearSchedule 或 WarmupCosineSchedule 这样的自定义学习率调度器可能会在初始化时根据预热步数和初始学习率来计算起始学习率。假设预热阶段的起始学习率是初始学习率的 1/warmup_steps,即如果初始学习率是 0.001,warmup_steps 为 500,那么起始学习率就是 0.001 / 500 = 0.000002。随着训练步数的增加,学习率会线性增长,直到达到初始学习率。
二、gradient_accumulation_steps梯度累积
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
if (step + 1) % args.gradient_accumulation_steps == 0:
losses.update(loss.item() * args.gradient_accumulation_steps)
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
scheduler.step()
optimizer.step()
optimizer.zero_grad()
global_step += 1
三、amp自动混合精度
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O2',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument('--loss_scale', type=float, default=0,
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n")
第一行