深度解析torch.save 和torch.load的用法

在PyTorch中,torch.save()和torch.load()默认使用Python的pickle模块(二进制序列化和反序列化)进行操作,因此你也可以将多个张量保存为Python对象的一部分,如元组、列表和字典(实际上,凡是能pickle的数据结构,包括自建的数据结构,均可save和load),如:
#保存一个字典:

d = {
   'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
torch.save(d, 'tensor_dict.pt')
torch.load('tensor_dict.pt')
{
   'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

#还可以保存一个列表:

numbers = torch.arange(1, 10)
evens = numbers[1::2]#[1::2]表示从第2个元素(‘1’)开始,按照步长为2的方式取数
torch.save([numbers, evens], 'tensors.pt')
loaded_numbers, loaded_evens = torch.load('tensors.pt')
loaded_evens *= 2
loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])
  1. 保存和加载模型
    1.1 什么是state_dict?
    在PyTorch中,torch.nn.Module模型的可学习参数(即权重和偏差)包含在模型的参数中(可使用model.parameters()访问)。state_dict只是一个Python字典对象,将每一层映射到其参数张量。请注意,在模型的state_dict中,只有具有可学习参数的层(卷积层、线性层等)和注册的缓冲区(batchnorm的running_mean)才有条目(即为健key)。优化器对象(torch.optim)也有一个state_dict,其中包含有关优化器状态以及使用的超参数的信息。因为state_dict对象是Python字典,所以可以轻松地保存、更新、修改和恢复,为PyTorch模型和优化器增加了大量的模块化操作。例如定义一个模型:
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值