深入理解Pytorch中模型保存文件pth

一、Pytorch中模型保存和加载方法

本文在介绍Pytorch中模型保存文件pth之前,将先探讨如模型的保存/加载的方法。
三个核心函数:

  • torch.save:把序列化的对象保存到硬盘。利用Python的pickle来实现序列化。模型、tensor以及字典都可以用该函数进行保存;
  • torch.load:采用 pickle 将反序列化的对象从存储中加载进来。
  • torch.nn.Module.load_state_dict:采用一个反序列化的state_dict加载一个模型的参数字典。

保存/加载模型

在Pytorch中,模型的保存和加载主要有两种方法,一种是保存/加载整个模型,另一种是只保存/加载模型参数

1. 保存整个模型

这种方法保存和加载模型都是采用最简单的语法。这种方法将是采用Python的pickle模块来保存整个模型,它的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是pickle并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。

示例方法:

# 保存整个模型
torch.save(model, PATH)
# 加载整个模型
model = torch.load(PATH)
model.eval()

造成的影响:

import torch

pthfile = r'your_path/model.pth' 
net = torch.load(pthfile, map_location=torch.device('cpu')) # 加载模型

print(type(net)) # 类型是 dict
print(len(net)) # 长度为 4,即存在四个 key-value 键值对

for k in net.keys():
print(k) # 查看四个键,分别是 model,optimizer,scheduler,iteration

save_all_model.png
如上图所示,保存的完整模型是一个字典,包含了模型、优化器、学习率调整器、迭代次数等信息。这种方法的缺点是,如果要加载模型,必须要保证模型的类别和结构不变,否则会报错。
例如,如果加载了之前的模型,但是想修改学习率之类的参数,由于加载的模型中保存了optimizer和scheduler,所以再次加载此文件时会使用之前的学习率。如果想要修改参数,其实只需要将模型权重加载进来就可以了,不需要再加载optimizer和scheduler。
如下面的代码:

import torch
 
net = torch.load('your_path/model.pth')
 
new = {
   
   "model": net["model"]} # 只保存模型的参数
torch.save(new, 'your_new_path/model.pth')

或者在保存模型时,只保存模型的参数,如后面1.2节所示。

2. 仅保存模型参数(官方推荐)

当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。
示例代码:

torch.save(model.state_dict()
评论 4
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值