Pytorch学习(十七)--- 模型load各种问题解决

简单的模型load

一般来说,保存模型是把参数全部用model.cpu().state_dict(), 然后加载模型时一般用 model.load_state_dict(torch.load(model_path))。 值得注意的是:torch.load 返回的是一个 OrderedDict.

import torch
import torch.nn as nn

class Net_old(nn.Module):
    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(2, 1, 3),
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)
        )
    def forward(self, x):
        return self.nets(x)

class Net_new(nn.Module):
    def __init__(self):
        super(Net_old, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)
        self.r1 = torch.nn.ReLU(True)
        self.conv2 = torch.nn.Conv2d(2, 1, 3)
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x

network = Net_old()
torch.save(network.cpu().state_dict(), 't.pth')

pretrained_net = torch.load('t.pth')
print(pretrained_net)

for key, v in enumerate(pretrained_net):
    print key, v

可以看到

OrderedDict([('nets.0.weight',
(0 ,0 ,.,.) =
 -0.2436  0.2523  0.3097
 -0.0315 -0.1307  0.0759
  0.0750  0.1894 -0.0761

(1 ,0 ,.,.) =
  0.0280 -0.2178  0.0914
  0.3227 -0.0121 -0.0016
 -0.0654 -0.0584 -0.1655
[torch.FloatTensor of size 2x1x3x3]
), ('nets.0.bias',
-0.0507
-0.2836
[torch.FloatTensor of size 2]
), ('nets.2.weight',
(0 ,0 ,.,.) =
 -0.2233  0.0279 -0.0511
 -0.0242 -0.1240 -0.0511
  0.2266  0.1385 -0.1070

(0 ,1 ,.,.) =
 -0.0943 -0.1403  0.0979
 -0.2163  0.1906 -0.2269
 -0.1984  0.0843 -0.0719
[torch.FloatTensor of size 1x2x3x3]
), ('nets.2.bias',
-0.1420
[torch.FloatTensor of size 1]
), ('nets.4.weight',
(0 ,0 ,.,.) =
  0.1981 -0.0250  0.2429
  0.3012  0.2428 -0.0114
  0.2878 -0.2134  0.1173
[torch.FloatTensor of size 1x1x3x3]
), ('nets.4.bias',
1.00000e-02 *
 -5.8426
[torch.FloatTensor of size 1]
)])
0 nets.0.weight
1 nets.0.bias
2 nets.2.weight
3 nets.2.bias
4 nets.4.weight
5 nets.4.bias

说明.state_dict()只是把所有模型的参数都以OrderedDict</

评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值