Pytorch | 两种模型的保存与加载之间的区别及联系

本文详细阐述PyTorch中模型的保存和加载,包括`torch.save`和`torch.load`函数的使用,以及如何保存和加载模型参数。强调在加载前必须确保模型定义存在且结构匹配,以及在相同设备上加载的重要性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

 

本文主要介绍如何加载和保存 PyTorch 的模型的区别与联系。这里主要有两个核心函数:

  • torch.save :把序列化的对象保存到硬盘。它利用了 Python 的 pickle 来实现序列化。模型、张量以及字典都可以用该函数进行保存;
  • torch.load:采用 pickle 将反序列化的对象从存储中加载进来。

  1. 保存和加载整个模型:

    • 保存模型:使用torch.save()函数将整个模型保存到文件中。可以指定文件路径和名称以及要保存的对象。
      torch.save(model, 'model.pth')
    • 加载模型:使用torch.load()函数加载保存的模型。可以通过指定文件路径和名称来加载模型。
      model = torch.load('model.pth')
  2. 保存和加载模型参数:

    • 保存模型参数:使用state_dict()方法获取模型的参数字典,然后使用torch.save()函数将参数字典保存到文件中。
      torch.save(model.state_dict(), 'model_weights.pth')
    • 加载模型参数:首先需要创建一个与原始模型结构相同的实例,然后使用load_state_dict()方法加载保存的参数字典。
      model = ModelClass(*args, **kwargs) # 创建模型实例 model.load_state_dict(torch.load('model_weights.pth'))

请注意,在加载模型或模型参数之前,需要确保模型的定义已经存在,并且与保存时的模型结构相匹配。此外,建议使用相同的设备(CPU或GPU)来加载模型,以确保模型能够正常运行。

      3.加载前的模型定义的问题解答:

问题1 :

        “请注意,在加载模型或模型参数之前,需要确保模型的定义已经存在,并且与保存时的模型结构相匹配。此外,建议使用相同的设备(CPU或GPU)来加载模型,以确保模型能够正常运行”。这句话针对1的加载方式吗

回答1:

        对于第一种保存和加载整个模型的方式,确保模型的定义已经存在并且与保存时的模型结构相匹配非常重要。这是因为torch.save()函数会将整个模型的结构、参数和其他相关信息都保存到文件中,因此在加载模型时需要确保模型的定义与保存时的一致。同时,建议在加载模型时使用相同的设备(CPU或GPU)来加载模型,以确保模型能够正常运行。如果保存的模型是在GPU上训练的,那么在加载时也需要将模型加载到GPU上。否则,如果设备不匹配可能会导致错误。总而言之,无论是保存和加载整个模型还是保存和加载模型参数,都需要确保模型的定义与保存时的一致,并尽可能在相同的设备上加载模型,以保证模型的正确性和可用性

问题2:

(yolov5_py3.8) hu@gpu003:~/projects/$ python

Python 3.8.13 (default, Mar 28 2022, 11:38:47) >>> import torch

>>> new_m = torch.load('prune.pt')

>>>

为什么这样也能正确加载进来,并没有在之前定义模型呀

回答2:

        无论是那两种方式都必须在加载之前定义模型,这句话是正确的,之所以出现我们写代码时,有时未在之前定义模型也能正确加载进来是因为,加载代码所在的位置,或者说路径与保存时的路径一致,模型保存时会将路径也保存进去,并在加载时自动搜索该路径找到模型进行定义。如果换个位置,就不能正确导入模型了

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值