【目录】
什么是优化器
optimizer的属性
optimizer的方法
1、什么是优化器
-
优化器在机器学习中的作用
在损失函数中会得到一个loss值,即模型输出与真实值之间的一个差异
对于loss值我们通常会采取pytorch中的自动求导autograd模块去求取模型当中的参数的梯度grad
优化器拿到梯度grad后,进行一系列的优化策略去更新模型参数,使得loss值不断下降
-
导数、方向导数与梯度
学习参数通常是指权值或者偏置bias,更新的策略在神经网络中通常都会采用梯度下降方法
- 方向导数
两个自变量x,y,输出值为z可以认为是山坡的高度
多元函数的导数都是偏导,对x的偏导就是,固定y,求在x方向上的变化率;而方向有无穷多个,不仅仅x、y方向上有变化率,其他方向任意方向上都有变化率
- 梯度是一个向量,它向量的方向是使得方向导数最大的那个方向
梯度的方向为方向导数变化率最大的方向(方向为使得方向导数取得最大值的方向),朝着斜坡最陡峭的地方去。模长即为方向导数的值,即为方向导数的变化率。
梯度就是在当前这个点增长最快的一个方向,模长就是增长的速度。梯度的负方向就是下降最快的。
2、optimizer的基本属性
- defaults里面是基本的超参数:lr、momentum、dampending、weight_decay、nesterov等,用字典打包;
- state没有训练时为空;
- param_groups是一个list,list中每一个元素是一个字典,字典中的'params'才是我们真正的参数,'params'里面的数据又是一个list,list中的一个元素又是一组参数
- params里面的内容为网络的可学习参数,比如,00表示第一个卷积层权重,01表示bias
- _step_count记录更新次数,以便于调整训练策略,例如每更新100次参数,对学习率调整一下
3、optimizer的基本方法
-
zero_grad()
在每一次反向传播过程中,需要对梯度张量进行清零,因为pytorch张量梯度不能自动清零,是累加的
# ----------------------------------- zero_grad -----------------------------------
flag = 0
# flag = 1
if flag:
print("weight before step:{}".format(weight.data))
optimizer.step() # 修改lr=1 0.1观察结果
#更新参数
print("weight after step:{}".format(weight.data))
print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))
print("weight.grad is {}\n".format(weight.grad))
optimizer.zero_grad() #清空梯度
print("after optimizer.zero_grad(), weight.grad is\n{}".format(weight.grad))
-
step()
执行,对参数进行更新
优化器保存的权值地址和真实权值的地址是一样的,说明优化器是通过地址去寻找权值
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
import torch
import torch.optim as optim
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)
from tools.common_tools import set_seed
set_seed(1) # 设置随机种子
weight = torch.randn((2, 2), requires_grad=True)#权重初始值设置,把梯度打开
weight.grad = torch.ones((2, 2))#权重初始梯度设置
optimizer = optim.SGD([weight], lr=0.1)#把参数weight放进去,把学习率设置为0.1
#学习率即为一个步长
# ----------------------------------- step -----------------------------------
flag = 0
# flag = 1
if flag:
print("weight before step:{}".format(weight.data))
optimizer.step() # 修改lr=1 0.1观察结果
#step更新参数
print("weight after step:{}".format(weight.data))
-
add_param_group()
通过参数组param_groups的概念来设置不同的参数有不同的学习率,在模型的fineturn中非常实用
# ----------------------------------- add_param_group -----------------------------------
flag = 0
# flag = 1
if flag:
print("optimizer.param_groups is\n{}".format(optimizer.param_groups))
w2 = torch.randn((3, 3), requires_grad=True)
optimizer.add_param_group({"params": w2, 'lr': 0.0001})
#添加参数组
print("optimizer.param_groups is\n{}".format(optimizer.param_groups))
-
state_dict() 与 load_state_dict()
state_dict()和load_state_dict()用于保存模型信息和加载模型信息,用于模型断点的续训练,一般在10个epoch或50个epoch的时候要保存当前的状态信息;
state_dict()返回两个字典,一个state用于保存缓存信息packed_state和一个param_groups;
load_state_dict()用于接收一个state_dict放在优化器中;
# ----------------------------------- state_dict -----------------------------------
flag = 0
# flag = 1
if flag:
optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
opt_state_dict = optimizer.state_dict()
print("state_dict before step:\n", opt_state_dict)
#打印没有执行之前的state_dict
for i in range(10):
optimizer.step()#更新参数
print("state_dict after step:\n", optimizer.state_dict())
#打印更新了10步之后的state_dict
#打印输出:{'state':{1818370992672:{'momentum_buffer':tensor([[6.5132,6.5132
#[6.5132,6.5132]])}},'param_groups':[{'lr':0.1,'momentum':0.9, 'dampending':0 省略...省略 'params':[1818370992672]}]}
#1818370992672就是params的地址,根据地址索引权重等信息
torch.save(optimizer.state_dict(), os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))
#将字典信息保存成.pkl文件,save的第一个参数为获取的state_dict()
# -----------------------------------load state_dict -----------------------------------
flag = 0
# flag = 1
if flag:
#如果想继续训练,则重新构建一个优化器,然后将state_dict的信息添加进来,就可以接着上一次的状态往下训练
optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
state_dict = torch.load(os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))
print("state_dict before load state:\n", optimizer.state_dict())
optimizer.load_state_dict(state_dict)
print("state_dict after load state:\n", optimizer.state_dict())