深入理解Pytorch中模型保存文件pth

一、Pytorch中模型保存和加载方法

本文在介绍Pytorch中模型保存文件pth之前,将先探讨如模型的保存/加载的方法。
三个核心函数:

  • torch.save:把序列化的对象保存到硬盘。利用Python的pickle来实现序列化。模型、tensor以及字典都可以用该函数进行保存;
  • torch.load:采用 pickle 将反序列化的对象从存储中加载进来。
  • torch.nn.Module.load_state_dict:采用一个反序列化的state_dict加载一个模型的参数字典。

保存/加载模型

在Pytorch中,模型的保存和加载主要有两种方法,一种是保存/加载整个模型,另一种是只保存/加载模型参数

1. 保存整个模型

这种方法保存和加载模型都是采用最简单的语法。这种方法将是采用Python的pickle模块来保存整个模型,它的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是pickle并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。

示例方法:

# 保存整个模型
torch.save(model, PATH)
# 加载整个模型
model = torch.load(PATH)
model.eval()

造成的影响:

import torch

pthfile = r'your_path/model.pth' 
net = torch.load(pthfile, map_location=torch.device('cpu')) # 加载模型

print(type(net)) # 类型是 dict
print(len(net)) # 长度为 4,即存在四个 key-value 键值对

for k in net.keys():
print(k) # 查看四个键,分别是 model,optimizer,scheduler,iteration

save_all_model.png
如上图所示,保存的完整模型是一个字典,包含了模型、优化器、学习率调整器、迭代次数等信息。这种方法的缺点是,如果要加载模型,必须要保证模型的类别和结构不变,否则会报错。
例如,如果加载了之前的模型,但是想修改学习率之类的参数,由于加载的模型中保存了optimizer和scheduler,所以再次加载此文件时会使用之前的学习率。如果想要修改参数,其实只需要将模型权重加载进来就可以了,不需要再加载optimizer和scheduler。
如下面的代码:

import torch
 
net = torch.load('your_path/model.pth')
 
new = {
   
   "model": net["model"]} # 只保存模型的参数
torch.save(new, 'your_new_path/model.pth')

或者在保存模型时,只保存模型的参数,如后面1.2节所示。

2. 仅保存模型参数(官方推荐)

当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。
示例代码:

torch.save(model.state_dict()
<think>好的,我现在需要解决用户的问题:如何在PyTorch中从.pth文件加载模型缓冲区。根据用户提供的引用内容,我需要先理解问题所在以及相关的解决方法。 首先,用户提到在加载自定义网络权重时遇到了错误:AttributeError: 'dict' object has no attribute 'seek'。根据引用[1],这个错误是因为尝试从无法寻址的文件对象加载模型,而解决方案是使用像io.BytesIO这样的缓冲区预先加载数据。用户可能保存模型时使用了不正确的方法,比如保存的是字典而不是模型的状态字典。引用[2]指出,.pth文件可能包含多个键值对,比如model、optimizer等,所以需要正确访问模型的状态。 接下来,我需要回顾PyTorch模型保存和加载的标准方法。通常,使用torch.save(model.state_dict(), 'model.pth')来保存模型的状态字典,然后通过model.load_state_dict(torch.load('model.pth'))来加载。如果用户保存的是整个模型或字典,加载时需要正确处理这些结构。 根据引用[3],torch.load()函数允许通过map_location参数指定加载设备。如果用户的.pth文件包含状态字典,正确的方法应该是加载后使用load_state_dict。如果文件包含其他键,比如优化器的状态,可能需要分别处理。 用户的问题是关于加载模型缓冲区的,这可能涉及模型中的buffer(如BatchNorm的running_mean)。这些buffer在状态字典中会自动保存和加载,因此只要正确加载状态字典,缓冲区应该会被正确恢复。 总结步骤:首先确认.pth文件的结构,如果是包含'model'键的字典,需要提取该部分作为状态字典;然后确保模型实例化后,用load_state_dict加载该状态字典。同时,使用正确的设备映射,如map_location='cpu',避免GPU相关的问题。 此外,用户可能需要检查保存模型时代码是否正确,是否只保存模型的状态而非整个模型对象。如果保存的是包含其他信息的字典,加载时需要适当处理,例如net = torch.load(...)然后model.load_state_dict(net['model'])。 最后,生成相关问题,帮助用户深入理解模型加载的细节和可能遇到的常见问题。</think>在PyTorch中从`.pth`文件加载模型缓冲区需遵循以下步骤: ### 1. 理解.pth文件结构 `.pth`文件通常包含: - **模型参数**(`state_dict`):包含可学习参数(如权重、偏置) - **缓冲区**(buffer):如BatchNorm层的`running_mean`、`running_var` - 其他训练状态(如优化器参数)[^2] ### 2. 标准加载方法 ```python import torch from io import BytesIO # 从文件加载到缓冲区 with open("model.pth", "rb") as f: buffer = BytesIO(f.read()) # 加载模型 state_dict = torch.load(buffer, map_location="cpu") # 指定设备 model = YourModelClass() # 实例化模型 model.load_state_dict(state_dict) ``` *关键点*:通过`BytesIO`创建可寻址的缓冲区对象,避免`'dict' object has no attribute 'seek'`错误[^1] ### 3. 处理复杂文件结构 若`.pth`文件包含多个键值对(如同时保存模型、优化器): ```python checkpoint = torch.load("model.pth", map_location="cpu") model.load_state_dict(checkpoint["model"]) # 提取模型参数 optimizer.load_state_dict(checkpoint["optimizer"]) # 提取优化器状态 ``` ### 4. 缓冲区自动加载原理 PyTorch的`state_dict`自动包含: - 通过`register_buffer()`注册的持久缓冲区 - 非持久缓冲区需手动处理 ```python class YourModel(nn.Module): def __init__(self): super().__init__() self.register_buffer("custom_buffer", torch.zeros(10)) # 自动保存/加载 ``` ### 5. 验证缓冲区加载 ```python print(model.state_dict().keys()) # 应包含所有参数和注册的缓冲区 print(model.custom_buffer) # 检查缓冲区值是否加载成功 ``` ### 6. 常见问题解决 **问题1:KeyError** 原因:模型结构与保存的`state_dict`不匹配 解决方案: ```python model = YourModel() model.load_state_dict(state_dict, strict=False) # 忽略不匹配键 ``` **问题2:设备不匹配** 通过`map_location`参数指定加载位置: ```python # 强制加载到CPU(即使原模型在GPU) state_dict = torch.load("model.pth", map_location=torch.device("cpu")) ```
评论 4
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值