pytorch(三)——保存自己的模型

本文详细介绍如何在PyTorch中保存模型参数及整个模型,包括两种保存方法的具体操作步骤,以及可能出现的错误及其解决办法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

pytorch(三)——保存自己的模型

模型依照前一片文章訓練好之後應該怎麼保存呢?

有兩個方法

1.保存模型參數


torch.save(the_model.state_dict(), PATH)


the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

2.保存整個模型(推薦)


torch.save(the_model, PATH)

the_model = torch.load(PATH)

可能出現報錯

AttributeError: Can't get attribute 'Net' on module '__main__'
解決方法:補充上網絡的結構就可以了

測試

model = the_model()
print(model)

如果能打印出網絡結構,則成功了~~

### 如何在 PyTorch保存模型 为了保存训练好的 PyTorch 模型,通常有两种方法:一种是保存整个模型对象;另一种则是只保存模型的状态字典(state dictionary)。这两种方式各有优劣。 #### 方法一:保存完整的模型对象 通过 `torch.save` 函数可以直接序列化并存储整个模型实例到磁盘文件中。这种方式简单直观,适合快速存档或迁移学习场景下的应用[^2]。 ```python import torch model = ... # 定义好待保存的神经网络结构 torch.save(model, 'path_to_saved_model/model.pt') ``` 这种方法的优点在于恢复时只需加载即可立即投入使用,无需重新构建相同的类定义。然而缺点也很明显——它依赖于特定版本的 Python 和 PyTorch 库环境,并且可能因为这些外部因素而难以兼容不同平台上的部署需求。 #### 方法二:仅保存状态字典 更推荐的做法是单独提取出模型参数部分作为独立的数据包来处理。这不仅减少了不必要的冗余信息量,还提高了跨版本间的稳定性以及灵活性[^1]。 ```python # 只保存模型权重而不是完整架构 torch.save(model.state_dict(), 'path_to_saved_weights/weights.pth') # 加载已有的预训练权重至新创建相同配置的对象上 new_model = TheModelClass(*args, **kwargs) new_model.load_state_dict(torch.load('path_to_saved_weights/weights.pth')) new_model.eval() ``` 上述代码片段展示了如何利用 `.state_dict()` 接口获取当前模型内部各层可训练变量的具体数值表示形式,并将其持久化为 .pth 文件格式以便后续读取操作。当需要再次激活该模型执行推理任务前,则应先调用对应的构造函数初始化一个新的同构实体再填充回之前记录下来的权值集合,最后记得切换评估模式以关闭 dropout 等随机机制影响预测效果。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值