本文主要介绍如何加载和保存 PyTorch 的模型的区别与联系。这里主要有两个核心函数:
torch.save
:把序列化的对象保存到硬盘。它利用了 Python 的pickle
来实现序列化。模型、张量以及字典都可以用该函数进行保存;torch.load
:采用pickle
将反序列化的对象从存储中加载进来。
-
保存和加载整个模型:
- 保存模型:使用
torch.save()
函数将整个模型保存到文件中。可以指定文件路径和名称以及要保存的对象。torch.save(model, 'model.pth')
- 加载模型:使用
torch.load()
函数加载保存的模型。可以通过指定文件路径和名称来加载模型。model = torch.load('model.pth')
- 保存模型:使用
-
保存和加载模型参数:
- 保存模型参数:使用
state_dict()
方法获取模型的参数字典,然后使用torch.save()
函数将参数字典保存到文件中。torch.save(model.state_dict(), 'model_weights.pth')
- 加载模型参数:首先需要创建一个与原始模型结构相同的实例,然后使用
load_state_dict()
方法加载保存的参数字典。model = ModelClass(*args, **kwargs) # 创建模型实例 model.load_state_dict(torch.load('model_weights.pth'))
- 保存模型参数:使用
请注意,在加载模型或模型参数之前,需要确保模型的定义已经存在,并且与保存时的模型结构相匹配。此外,建议使用相同的设备(CPU或GPU)来加载模型,以确保模型能够正常运行。
3.加载前的模型定义的问题解答:
问题1 :
“请注意,在加载模型或模型参数之前,需要确保模型的定义已经存在,并且与保存时的模型结构相匹配。此外,建议使用相同的设备(CPU或GPU)来加载模型,以确保模型能够正常运行”。这句话针对1的加载方式吗
回答1:
对于第一种保存和加载整个模型的方式,确保模型的定义已经存在并且与保存时的模型结构相匹配非常重要。这是因为torch.save()
函数会将整个模型的结构、参数和其他相关信息都保存到文件中,因此在加载模型时需要确保模型的定义与保存时的一致。同时,建议在加载模型时使用相同的设备(CPU或GPU)来加载模型,以确保模型能够正常运行。如果保存的模型是在GPU上训练的,那么在加载时也需要将模型加载到GPU上。否则,如果设备不匹配可能会导致错误。总而言之,无论是保存和加载整个模型还是保存和加载模型参数,都需要确保模型的定义与保存时的一致,并尽可能在相同的设备上加载模型,以保证模型的正确性和可用性
问题2:
(yolov5_py3.8) hu@gpu003:~/projects/$ python
Python 3.8.13 (default, Mar 28 2022, 11:38:47) >>> import torch
>>> new_m = torch.load('prune.pt')
>>>
为什么这样也能正确加载进来,并没有在之前定义模型呀
回答2:
无论是那两种方式都必须在加载之前定义模型,这句话是正确的,之所以出现我们写代码时,有时未在之前定义模型也能正确加载进来是因为,加载代码所在的位置,或者说路径与保存时的路径一致,模型保存时会将路径也保存进去,并在加载时自动搜索该路径找到模型进行定义。如果换个位置,就不能正确导入模型了