Pytorch入门(4)—— Tensor和Module的保存与加载

  • 参考:动手学深度学习
  • 注意:由于本文是jupyter文档转换来的,代码不一定可以直接运行,有些注释是jupyter给出的交互结果,而非运行结果!!

1. 读写 Tensor

  • pytorch 提供了 torch.save 函数和 torch.load 函数,分别用于存储和读取 Tensor,其中

    1. torch.save 使用 Python 的 pickle 持续化模块将对象进行序列化并保存到本地磁盘,torch.save 可以保存各种对象,包括模型、张量和字典等
    2. torch.load 使用 pickle unpickle 工具将 pickle 的对象文件反序列化为内存变量
  • 下面创建 Tensor 变量 x,并将保存为本地文件 x.pt,然后再读取回来

    import torch
    from torch import nn
    
    x = torch.ones(3)
    torch.save(x, 'x.pt')
    
    x2 = torch.load('x.pt')
    print(x2) # tensor([1., 1., 1.])
    
  • 类似地,存储Tensor列表和字典并读回内存

    x = torch.ones(2)
    y = torch.zeros(4)
    torch.save([x, y], 'xy_list.pt')
    xy_list = torch.load('xy_list.pt')
    print(xy_list)
    
    torch.save({'x': x, 'y': y}, 'xy_dict.pt')
    xy_dict = torch.load('xy_dict.pt')
    print(xy_dict)
    
    [tensor([1., 1.]), tensor([0., 0., 0., 0.])]
    {'x': tensor([1., 1.]), 'y': tensor([0., 0., 0., 0.])}
    

2. 读写 Module

2.1 state_dict

  • 保存/加载 Module 的一个思路是保存和加载其所有参数。PyTorch 中 Module 的可学习参数包括权重 weight 和偏置 bias,它们可以通过 .parameters().named_parameter() 方法访问

  • 调用 Module 的 .state_dict() 方法,返回一个从参数名称(“layer名.weight” 或 “layer名.bias”)映射到到参数 Tesnor 的字典对象,其中包含了模型的所有可学习参数

    class MLP(nn.Module):
        def __init__(self):
            super(MLP, self).__init__()
            self.hidden = nn.Linear(3, 2)
            self.act = nn.ReLU()
            self.output = nn.Linear(2, 1)
    
        def forward(self, x):
            a = self.act(self.hidden(x))
            return self.output(a)
    
    net = MLP()
    net.state_dict()
    
    OrderedDict([('hidden.weight',
                  tensor([[-0.0749,  0.2594, -0.5274],
                          [ 0.3983, -0.2925,  0.5102]])),
                 ('hidden.bias', tensor([-0.1095, -0.4895])),
                 ('output.weight', tensor([[ 0.2644, -0.3989]])),
                 ('output.bias', tensor([-0.2737]))])
    

    可见,只有具有可学习参数的层(卷积层、线性层等)才有 state_dict 中的条目;另外,优化器(optim)也有一个 state_dict,其中包含关于优化器状态以及所使用的超参数的信息

    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    optimizer.state_dict()
    
    {'state': {},
     'param_groups': [{'lr': 0.001,
       'momentum': 0.9,
       'dampening': 0,
       'weight_decay': 0,
       'nesterov': False,
       'params': [0, 1, 2, 3]}]}
    

2.2 保存和加载模型

  • PyTorch 中保存和加载训练模型有两种常见的方法:
    1. 仅保存和加载模型参数(state_dict),这是推荐方式
    2. 保存和加载整个模型

2.2.1 保存和加载 state_dict(推荐)

  • 通过保存和加载参数来实现模型的存取,相比直接保存整个 Module 对象轻量很多

    1. 保存时,使用 torch.save 保存模型的 state_dict 字典实例
    2. 加载时,先使用 torch.load 加载模型的 state_dict,然后实例化一个所需类型的 Module,最后调用 Moudule 的 .load_state_dict 方法加载参数
  • 给出一个读写多层感知机模型的示例

    # 原始模型
    print('原始模型')
    net = MLP()
    print(net,'\n')
    for name, param in net.named_parameters():
        print(name, param)
    
    # 保存 & 加载
    print('\n\n保存再加载模型')
    torch.save(net.state_dict(), 'mlp.pt') # 推荐的文件后缀名是pt或pth
    model = MLP()
    model.load_state_dict(torch.load('mlp.pt'))
    
    print(model,'\n')
    for name, param in model.named_parameters():
        print(name, param)
    
    原始模型
    MLP(
      (hidden): Linear(in_features=3, out_features=2, bias=True)
      (act): ReLU()
      (output): Linear(in_features=2, out_features=1, bias=True)
    ) 
    
    hidden.weight Parameter containing:
    tensor([[ 0.3508,  0.5170,  0.2746],
            [-0.0571, -0.1772,  0.2815]], requires_grad=True)
    hidden.bias Parameter containing:
    tensor([0.1041, 0.1207], requires_grad=True)
    output.weight Parameter containing:
    tensor([[-0.4299, -0.2678]], requires_grad=True)
    output.bias Parameter containing:
    tensor([-0.3851], requires_grad=True)
    
    
    保存再加载模型
    MLP(
      (hidden): Linear(in_features=3, out_features=2, bias=True)
      (act): ReLU()
      (output): Linear(in_features=2, out_features=1, bias=True)
    ) 
    
    hidden.weight Parameter containing:
    tensor([[ 0.3508,  0.5170,  0.2746],
            [-0.0571, -0.1772,  0.2815]], requires_grad=True)
    hidden.bias Parameter containing:
    tensor([0.1041, 0.1207], requires_grad=True)
    output.weight Parameter containing:
    tensor([[-0.4299, -0.2678]], requires_grad=True)
    output.bias Parameter containing:
    tensor([-0.3851], requires_grad=True)
    

2.2.2 保存和加载整个模型

  • 第 1 节说明了 torch.save 可以保存各种对象,包括模型、张量和字典等 ,所以也可以直接对 Module 实例使用 torch.savetorch.load 进行保存加载

    # 原始模型
    print('原始模型')
    net = MLP()
    print(net,'\n')
    for name, param in net.named_parameters():
        print(name, param)
    
    # 保存 & 加载
    print('\n\n保存再加载模型')
    torch.save(net, 'mlp.pt') # 推荐的文件后缀名是pt或pth
    model = torch.load('mlp.pt')
    
    print(model,'\n')
    for name, param in model.named_parameters():
        print(name, param)
    
    原始模型
    MLP(
      (hidden): Linear(in_features=3, out_features=2, bias=True)
      (act): ReLU()
      (output): Linear(in_features=2, out_features=1, bias=True)
    ) 
    
    hidden.weight Parameter containing:
    tensor([[ 0.1274, -0.4678, -0.5102],
            [ 0.3547,  0.3654, -0.2288]], requires_grad=True)
    hidden.bias Parameter containing:
    tensor([-0.3877,  0.0637], requires_grad=True)
    output.weight Parameter containing:
    tensor([[0.5678, 0.6307]], requires_grad=True)
    output.bias Parameter containing:
    tensor([0.0517], requires_grad=True)
    
    
    保存再加载模型
    MLP(
      (hidden): Linear(in_features=3, out_features=2, bias=True)
      (act): ReLU()
      (output): Linear(in_features=2, out_features=1, bias=True)
    ) 
    
    hidden.weight Parameter containing:
    tensor([[ 0.1274, -0.4678, -0.5102],
            [ 0.3547,  0.3654, -0.2288]], requires_grad=True)
    hidden.bias Parameter containing:
    tensor([-0.3877,  0.0637], requires_grad=True)
    output.weight Parameter containing:
    tensor([[0.5678, 0.6307]], requires_grad=True)
    output.bias Parameter containing:
    tensor([0.0517], requires_grad=True)
    
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

云端FFF

所有博文免费阅读,求打赏鼓励~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值