调整学习率通常与实际算法同样重要。PyTorch 提供了多种学习率调度器,每种调度器都有其独特的优点和适用场景。
常用调整器类型
1. StepLR
- 描述: 每隔一定数量的 epochs,将学习率乘以一个固定的因子。
- 优点: 实现简单,适用于学习率需要按照固定步长减小的情况。
- 缺点: 步长和衰减率需要手动设置,可能不够灵活。
2. MultiStepLR
- 描述: 类似于 StepLR,但允许在一系列 epoch 上定义不同的衰减点。
- 优点: 比 StepLR 更灵活,可以自定义多个下降点。
- 缺点: 同样需要手动设置衰减点和衰减率。
3. ExponentialLR
- 描述: 每个 epoch 将学习率乘以一个固定的衰减率,实现指数衰减。
- 优点: 实现简单,适用于需要平滑减少学习率的情况。
- 缺点: 需要手动设置衰减率,且衰减过程是连续的,不如阶梯式衰减灵活。
4. CosineAnnealingLR
- 描述: 按照余弦函数周期性地调整学习率,有利于在训练的后期进行精细调整。
- 优点: 可以避免局部最小值,周期性的调整有助于探索更多的参数空间。
- 缺点: 调整周期和学习率范围需要手动设置。
5. ReduceLROnPlateau
- 描述: 根据监测的指标自动调整学习率,当性能停止改善时降低学习率。
- 优点: 基于模型的实际表现动态调整学习率,适用于不确定何时调整学习率的情况。
- 缺点: 可能需要多个 epoch 才能调整学习率,反应较慢。
6. OneCycleLR
- 描述: 在训练过程中先增加然后减少学习率,适用于短时间训练高性能模型。
- 优点: 可以快速收敛,适用于训练周期较短的任务。
- 缺点: 需要精细调整学习率的最大值和最小值。
ReduceLROnPlateau调整器
参数解释
- mode:(‘min’ 或 ‘max’): 如果设置为 ‘min’,学习率会在监控的指标停止下降时降低;如果设置为 ‘max’,则在指标停止上升时降低,通常设置为min。
- factor:学习率降低的比率。新的学习率等于旧的学习率乘以这个因子。
- patience:在降低学习率之前允许指标停止改进的 epoch 数。
- verbose:设置为True会在更改学习率时打印信息。
- min_lr:学习率的下限。
- threshold:对于衡量新的最优值的改善,这是一个阈值。
- threshold_mode: (‘rel’ 或 ‘abs’): ‘rel’ 表示相对改善,‘abs’ 表示绝对改善。
- cooldown:在学习率被降低后,增加一些 epoch 的“冷却时间”,在这段时间内不会进一步降低学习率。
- eps:为了数值稳定性而加到学习率上的最小增量。
使用示例
初始化
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5, verbose=True, min_lr=0.00001)
调用
scheduler.step(val_loss)
在每个训练 epoch 的末尾,使用验证集的性能指标来更新调度器。这里,val_loss 是在当前 epoch 后,模型在验证集上的平均损失。
完整示例
import torch
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
model = MyModel() # 用您的模型替换此处
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)
for epoch in range(num_epochs):
model.train()
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
model.eval()
val_loss = 0
with torch.no_grad():
for data, target in val_loader:
output = model(data)
val_loss += criterion(output, target).item()
val_loss /= len(val_loader)
print(f'Epoch: {epoch+1}, Val Loss: {val_loss}')
# 更新调度器
scheduler.step(val_loss)
1196

被折叠的 条评论
为什么被折叠?



