【pytorch】保存模型、加载模型

文章详细介绍了在PyTorch中如何使用torch的save和load函数保存和加载模型及数据,包括state_dict的概念,以及如何保存和加载模型参数和优化器的状态。此外,还展示了如何仅保存和加载模型参数以及整个模型的步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、torch 的 save 和 load

我们可以直接使用 save 函数 和 load函数 进行存储和读取。

  • save 使用 Python 的 pickle 实用程序将对象进行序列化,然后将序列化的对象保存到disk。save可以保存各种对象,包括模型、张量和字典等。
  • load 使用 pickle unpickle 工具将 pickle 的对象文件反序列化为内存
import torch

x = [1, 2]
y = {'name':'xiaoming', 'age':16}
z = (x, y)

torch.save(z, 'z.pt')

z_new = torch.load('z.pt')
print(z_new)   # ([1, 2], {'name': 'xiaoming', 'age': 16})

2、state_dict

1)net.state_dict()

PyTorch中,Module 的可学习参数 (即权重和偏差),模块模型包含在参数中 (通过 model.parameters() 访问)。state_dict 是一个从参数名称隐射到参数 Tesnor 的有序字典对象。
注意,只有具有可学习参数的层(卷积层、线性层等) 才有 state_dict 中的条目。

import torch.nn as nn

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

net = MLP()
print(net.state_dict())

# OrderedDict([('hidden.weight', tensor([[-0.2360, -0.3193, -0.2618],[ 0.1759, -0.0888,  0.2635]])), 
#              ('hidden.bias', tensor([ 0.2161, -0.3944])), 
#              ('output.weight', tensor([[-0.5358, -0.2140]])), 
#              ('output.bias', tensor([0.6262]))])

2)optimizer.state_dict()

优化器(optim) 也有一个 state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

import torch
import torch.nn as nn


net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

print(optimizer.state_dict())
# {'state': {}, 
#  'param_groups': [{'lr': 0.001, 
#                    'momentum': 0.9, 
#                    'dampening': 0, 
#                    'weight_decay': 0, 
#                    'nesterov': False, 
#                    'maximize': False, 
#                    'foreach': None, 
#                    'params': [0, 1, 2, 3]}]}


3、保存模型 和 加载模型

1)仅保存和加载模型参数(state_dict)

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

# 保存模型参数
torch.save(net.state_dict(), 'model_weight.pt')   # 推荐的文件后缀名是pt或pth

# 下载模型参数
model_weight = torch.load('model_weight.pt')
print(model_weight)
# OrderedDict([('0.weight', tensor([[-0.3865, -0.4623,  0.1212],[-0.2480,  0.3840,  0.1916]])), 
#              ('0.bias', tensor([ 0.0698, -0.0641])), 
#              ('2.weight', tensor([[-0.1499, -0.2895]])), 
#              ('2.bias', tensor([0.2585]))])

# 下载模型参数 并放到模型中
net.load_state_dict(torch.load('model_weight.pt'))

2)保存 和 加载 整个模型

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 2), nn.ReLU(), nn.Linear(2, 1))

# 保存整个模型
torch.save(net, 'model.pt')

# 下载模型参数 并放到模型中
net_new = torch.load('model.pt')
print(net_new)
# Sequential(
#   (0): Linear(in_features=3, out_features=2, bias=True)
#   (1): ReLU()
#   (2): Linear(in_features=2, out_features=1, bias=True)
# )
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Enzo 想砸电脑

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值