文章目录
PyTorch 学习率调度器(LR Scheduler)
1. 一句话定义
每过一段时间 / 满足某条件,自动按规则修改优化器学习率的工具。
2. 通用使用套路
optimizer = torch.optim.Adam(model.parameters(), lr=初始LR)
scheduler = XXXLR(optimizer, ...) # 选下面任意一种
for epoch in range(EPOCH):
train(...)
val_loss = validate(...)
optimizer.step() # ① 先更新参数
scheduler.step(val_loss) # ② 再调度LR(ReduceLROnPlateau需传loss)
顺序:先 optimizer.step() → 再 scheduler.step(),否则报警告。
3. 内置调度器对比速览
| 调度器 | 触发规则 | 主要参数 | 参数解释 | 典型场景 |
|---|---|---|---|---|
| LambdaLR | 自定义函数 f(epoch) 返回乘数 | lr_lambda, last_epoch | lr_lambda: 接收 epoch,返回 LR 乘数;last_epoch: 重启训练时设为上次 epoch | warmup、分段线性 |
| StepLR | 固定每 step_size epoch 降一次 | step_size, gamma, last_epoch | step_size: 隔多少 epoch 降;gamma: 乘性衰减系数 | 常规“等间隔”下降 |
| MultiStepLR | 指定里程碑 epoch 列表降 | milestones, gamma, last_epoch | milestones: List,到这些 epoch 就 ×gamma | 训练中期多段下降 |
| CosineAnnealingLR | 余弦曲线从初始→η_min | T_max, eta_min, last_epoch | T_max: 半个余弦周期长度;eta_min: 最小 LR | 退火、cosine 重启 |
| ReduceLROnPlateau | 监控指标停止改善时降 | mode, factor, patience, threshold, cooldown, min_lr | 见下方详注 | 验证 loss/acc 卡住时 |
ReduceLROnPlateau 参数详注
mode='min'或'max':指标越小/越大越好factor=0.1:新 LR = 旧 LR × factorpatience=3:连续 3 次 epoch 无改善才降threshold=0.01:改善幅度小于阈值视为无改善cooldown=1:降 LR 后冻结监控的 epoch 数min_lr=1e-6:下限,降到此值不再降
4. 各调度器最小模板
① LambdaLR(线性 warmup)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: epoch / 5 if epoch < 5 else 1)
② StepLR
scheduler = StepLR(optimizer, step_size=2, gamma=0.1) # 每 2 epoch ×0.1
③ MultiStepLR
scheduler = MultiStepLR(optimizer, milestones=[2, 6], gamma=0.1)
④ CosineAnnealingLR
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)
⑤ ReduceLROnPlateau(必须传指标)
scheduler = ReduceLROnPlateau(
optimizer, mode='min', factor=0.1, patience=3,
threshold=0.01, cooldown=1, min_lr=1e-6)
val_loss = validate(...)
scheduler.step(val_loss) # ← 记得传指标
5. 常用调试 API
scheduler.get_last_lr() # 当前实际 LR 列表(每个 param_group)
scheduler.last_epoch # 已完成的 epoch 计数(从 0 开始)
6. 易踩坑 Top-3
- 先
optimizer.step()再scheduler.step()
否则报警告 “Detected call oflr_scheduler.step()beforeoptimizer.step()”。 - ReduceLROnPlateau 必须传监控值
不传 → RuntimeError。 - Lambda/MultiStep 等无需监控值,传了 → TypeError。
7. 速记口诀
“优化先迈步,调度再跟进;Plateau 传 loss,其余不用问。”
&spm=1001.2101.3001.5002&articleId=151184689&d=1&t=3&u=cfc22009b628495fa9457ffc3fe7d11e)
735

被折叠的 条评论
为什么被折叠?



