Environment
- OS: macOS Mojave
- Python version: 3.7
- PyTorch version: 1.4.0
- IDE: PyCharm
文章目录
0. 写在前面
PyTorch 在 torch.optim.lr_scheduler 中提供了十种调整学习率的类,它们的基类为 torch.optim.lr_scheduler._LRScheduler。
这里记常用的一些。详尽的参考 PyTorch 官方文档。
1. StepLR
StepLR 类,按固定间隔调整学习率
l r = l r × g a m m a lr = lr \times gamma lr=lr×gamma
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
# config
learning_rate = 1.
num_epochs = 50
# create weight, optimizer and scheduler
weight = torch.randn((1,), dtype=torch.float32)
optimizer = SGD([weight], lr=learning_rate, momentum=0.9)
scheduler = StepLR(
optimizer=optimizer,
step_size=10, # 设定间隔数
gamma=0.1, # 系数
last_epoch=-1
)
# train-like iteration
lrs, epochs = [], []
for epoch in range(num_epochs):
lrs.append(scheduler.get_lr())
epochs.append(epoch)
pass # iter and train here
scheduler.step()
# visualize
sns.set_style('whitegrid')
plt.figure(figsize=(10, 6))
plt.plot(epochs, lrs, label='StepLR')
plt.legend()
plt.show()

以 StepLR 类为例,记一下 scheduler 重要的属性和方法
1.1 scheduler 常用的属性
scheduler.optimizer关联优化器
print(scheduler.optimizer == optimizer) # True
print(id(scheduler.optimizer) == id(optimizer)) # True
scheduler.base_lrs存储各组的初始学习率
print(scheduler.base_lrs) # [1.0]
scheduler.last_epoch记录 epoch 数,应该是可以用在断点续训练中(但尚未用到)
print(scheduler.last_epoch) # 0
1.2 scheduler 常用的方法
scheduler.step()更新至下一个 epoch 的学习率scheduler.get_lr()获得当前各组的学习率,个别优化器类的实例没有该方法(如ReduceLROnPlateau)
PyTorch学习:调整学习率策略详解

本文是PyTorch学习笔记的第四部分,主要介绍了如何调整学习率,包括StepLR、MultiStepLR、ExponentialLR、ReduceLROnPlateau和LambdaLR五种策略。StepLR按固定间隔调整学习率,MultiStepLR在给定间隔调整,ExponentialLR指数衰减,ReduceLROnPlateau根据监控指标变化调整,LambdaLR则允许对不同参数组设置定制策略。
最低0.47元/天 解锁文章
7722

被折叠的 条评论
为什么被折叠?



