pytorch之保存模型训练好的参数状态以及直接加载该参数状态来进行预测

本文介绍了如何使用PyTorch框架中的torch.save()函数保存模型参数(state_dict)到model.pkl文件,以及如何使用torch.load()加载已训练模型进行预测。详细展示了在train.py和predict.py文件中的操作过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

使用 PyTorch 框架中的 torch.save() 函数将模型的参数保存到名为 model.pkl 的文件中。

使用 PyTorch 框架中的 torch.load() 函数从文件model.pkl 中加载已经训练好的模型的状态字典。

在 PyTorch 中,模型的参数通常存储在一个名为“state_dict”的字典对象中。这个字典对象包含了模型中所有可学习的参数及其对应的张量。这样,就可以将模型的参数以二进制格式保存到磁盘上。

1. 训练模型train.py

import torch
import torch.nn as nn

# 定义神经网络模型
class MyModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

# 创建模型实例
model = MyModel(input_size=10, hidden_size=20, output_size=2)

# 保存模型状态字典
torch.save(model.state_dict(), 'model.pkl')

2. 加载已经训练好的模型的参数并直接进行预测predict.py


# 加载模型状态字典
model.load_state_dict(torch.load('model.pkl', map_location=torch.device('cpu')))

# 使用模型进行预测
input_tensor = torch.randn(1, 10)
output_tensor = model(input_tensor)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值