一、Pytorch中模型保存和加载方法
本文在介绍Pytorch中模型保存文件pth之前,将先探讨如模型的保存/加载的方法。
三个核心函数:
- torch.save:把序列化的对象保存到硬盘。利用Python的pickle来实现序列化。模型、tensor以及字典都可以用该函数进行保存;
- torch.load:采用 pickle 将反序列化的对象从存储中加载进来。
- torch.nn.Module.load_state_dict:采用一个反序列化的state_dict加载一个模型的参数字典。
保存/加载模型
在Pytorch中,模型的保存和加载主要有两种方法,一种是保存/加载整个模型,另一种是只保存/加载模型参数。
1. 保存整个模型
这种方法保存和加载模型都是采用最简单的语法。这种方法将是采用Python的pickle模块来保存整个模型,它的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是pickle并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。
示例方法:
# 保存整个模型
torch.save(model, PATH)
# 加载整个模型
model = torch.load(PATH)
model.eval()
造成的影响:
import torch
pthfile = r'your_path/model.pth'
net = torch.load(pthfile, map_location=torch.device('cpu')) # 加载模型
print(type(net)) # 类型是 dict
print(len(net)) # 长度为 4,即存在四个 key-value 键值对
for k in net.keys():
print(k) # 查看四个键,分别是 model,optimizer,scheduler,iteration

如上图所示,保存的完整模型是一个字典,包含了模型、优化器、学习率调整器、迭代次数等信息。这种方法的缺点是,如果要加载模型,必须要保证模型的类别和结构不变,否则会报错。
例如,如果加载了之前的模型,但是想修改学习率之类的参数,由于加载的模型中保存了optimizer和scheduler,所以再次加载此文件时会使用之前的学习率。如果想要修改参数,其实只需要将模型权重加载进来就可以了,不需要再加载optimizer和scheduler。
如下面的代码:
import torch
net = torch.load('your_path/model.pth')
new = {
"model": net["model"]} # 只保存模型的参数
torch.save(new, 'your_new_path/model.pth')
或者在保存模型时,只保存模型的参数,如后面1.2节所示。
2. 仅保存模型参数(官方推荐)
当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。
示例代码:
torch.save(model.state_dict()

最低0.47元/天 解锁文章
1万+





