torch.optim.lr_scheduler.StepLR()
定义:PyTorch 中一种常用的学习率调度器
作用:在训练过程中按照固定间隔调整学习率。
原理:通过在每个指定的 epoch 数(step_size
)后将学习率乘以一个衰减因子(gamma
)来实现学习率的动态调整。
StepLR 的主要参数
optimizer
:需要包装的优化器对象,例如torch.optim.SGD
或torch.optim.Adam
。step_size
:学习率调整的间隔,即每隔多少个 epoch 调整一次学习率。gamma
:学习率衰减的乘法因子,默认值为 0.1,表示每次调整时学习率会乘以 0.1。last_epoch
:上一个 epoch 的索引,默认为 -1,表示从头开始训练。verbose
:是否打印学习率更新的信息,True
时会在每次更新时打印当前学习率。
StepLR 的工作原理
- 每经过
step_size
个 epoch,学习率会乘以gamma
。例如,初始学习率为 0.1,step_size=30
,gamma=0.1
,那么在第 30 个 epoch 后,学习率会变为 0.01,第 60 个 epoch 后变为 0.001。
例子
import torch
import torch.optim as optim
import torch.nn as nn
# 定义一个简单的模型
model = nn.Linear(10, 1)
# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.1)
# 定义 StepLR 调度器
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# 模拟训练过程
num_epochs = 50
for epoch in range(num_epochs):
# 模拟训练代码
optimizer.zero_grad()
# 假设的训练逻辑
optimizer.step()
# 更新学习率
scheduler.step()
# 打印当前学习率
print(f"Epoch {epoch + 1}, Learning Rate: {scheduler.get_last_lr()[0]}")
在这个例子中,学习率会在每 10 个 epoch 后乘以 0.1。
适用场景
- 大规模数据集:当训练周期较长时,逐步降低学习率可以稳定训练过程,避免学习率过高导致的训练不稳定。
- 需要逐步调整学习率的场景:适用于模型在训练初期需要较大学习率快速收敛,而在后期需要较小学习率进行精细调整。
注意事项
step_size
的选择:间隔设置过短会导致学习率下降过快,影响训练速度;间隔过长则可能无法充分利用学习率调整带来的稳定性。- 与其他调度器结合使用:可以将
StepLR
与其他调度器(如ReduceLROnPlateau
)结合使用,以实现更灵活的学习率调整策略。
通过合理配置 StepLR
,可以有效提升模型的训练效果和收敛速度。