Environment
- OS: macOS Mojave
- Python version: 3.7
- PyTorch version: 1.4.0
- IDE: PyCharm
文章目录
0. 写在前面
PyTorch 在 torch.optim.lr_scheduler
中提供了十种调整学习率的类,它们的基类为 torch.optim.lr_scheduler._LRScheduler
。
这里记常用的一些。详尽的参考 PyTorch 官方文档。
1. StepLR
StepLR
类,按固定间隔调整学习率
l r = l r × g a m m a lr = lr \times gamma lr=lr×gamma
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR
# config
learning_rate = 1.
num_epochs = 50
# create weight, optimizer and scheduler
weight = torch.randn((1,), dtype=torch.float32)
optimizer = SGD([weight], lr=learning_rate, momentum=0.9)
scheduler = StepLR(
optimizer=optimizer,
step_size=10, # 设定间隔数
gamma=0.1, # 系数
last_epoch=-1
)
# train-like iteration
lrs, epochs = [], []
for epoch in range(num_epochs):
lrs.append(scheduler.get_lr())
epochs.append(epoch)
pass # iter and train here
scheduler.step()
# visualize
sns.set_style('whitegrid')
plt.figure(figsize=(10, 6))
plt.plot(epochs, lrs, label='StepLR')
plt.legend()
plt.show()
以 StepLR
类为例,记一下 scheduler
重要的属性和方法
1.1 scheduler 常用的属性
scheduler.optimizer
关联优化器
print(scheduler.optimizer == optimizer) # True
print(id(scheduler.optimizer) == id(optimizer)) # True
scheduler.base_lrs
存储各组的初始学习率
print(scheduler.base_lrs) # [1.0]
scheduler.last_epoch
记录 epoch 数,应该是可以用在断点续训练中(但尚未用到)
print(scheduler.last_epoch) # 0