简单的模型load
一般来说，保存模型是先用 model.cpu().state_dict() 取出全部参数，再用 torch.save 写入文件；加载模型时一般用 model.load_state_dict(torch.load(model_path))。值得注意的是：torch.load 返回的是一个 OrderedDict。
import torch
import torch.nn as nn
class Net_old(nn.Module):
    """Three-convolution network whose layers live inside one ``nn.Sequential``.

    Because every layer is a child of the ``nets`` container, the parameter
    names in ``state_dict()`` are prefixed with the container name and the
    layer's position in it: ``nets.0.weight``, ``nets.0.bias``,
    ``nets.2.weight``, ``nets.2.bias``, ``nets.4.weight``, ``nets.4.bias``.
    Positions 1 and 3 are ReLU layers, which contribute no parameters.
    """

    def __init__(self):
        super(Net_old, self).__init__()
        self.nets = nn.Sequential(
            torch.nn.Conv2d(1, 2, 3),   # 1 -> 2 channels, 3x3 kernel
            torch.nn.ReLU(True),        # positional True is ``inplace``
            torch.nn.Conv2d(2, 1, 3),   # 2 -> 1 channels, 3x3 kernel
            torch.nn.ReLU(True),
            torch.nn.Conv2d(1, 1, 3)    # 1 -> 1 channels, 3x3 kernel
        )

    def forward(self, x):
        """Pass ``x`` through the sequential stack and return the output."""
        return self.nets(x)
class Net_new(nn.Module):
    """Three-convolution network with each layer stored as its own attribute.

    The layer types and constructor arguments match ``Net_old``, but the
    layers are registered directly on the module instead of inside an
    ``nn.Sequential``. The parameter names in ``state_dict()`` therefore
    follow the attribute names: ``conv1.weight``, ``conv1.bias``,
    ``conv2.weight``, ``conv2.bias``, ``conv3.weight``, ``conv3.bias``.
    """

    def __init__(self):
        # Must name this class: super(Net_old, self) fails here because
        # Net_new is not a subclass of Net_old.
        super(Net_new, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 2, 3)   # 1 -> 2 channels, 3x3 kernel
        self.r1 = torch.nn.ReLU(True)           # positional True is ``inplace``
        self.conv2 = torch.nn.Conv2d(2, 1, 3)   # 2 -> 1 channels, 3x3 kernel
        self.r2 = torch.nn.ReLU(True)
        self.conv3 = torch.nn.Conv2d(1, 1, 3)   # 1 -> 1 channels, 3x3 kernel

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU -> conv3 and return the output."""
        x = self.conv1(x)
        x = self.r1(x)
        x = self.conv2(x)
        x = self.r2(x)
        x = self.conv3(x)
        return x
# Save the parameters of a freshly initialised Net_old (moved to the CPU
# first), then read the file back to inspect what torch.load returns.
network = Net_old()
torch.save(network.cpu().state_dict(), 't.pth')
pretrained_net = torch.load('t.pth')
print(pretrained_net)
# Iterating the loaded dict yields its keys, so enumerate() produces
# (position, parameter name) pairs -- not (name, tensor) pairs.
# str.format keeps the "0 nets.0.weight" output identical on Python 2 and 3.
for index, name in enumerate(pretrained_net):
    print('{} {}'.format(index, name))
可以看到
OrderedDict([('nets.0.weight',
(0 ,0 ,.,.) =
-0.2436 0.2523 0.3097
-0.0315 -0.1307 0.0759
0.0750 0.1894 -0.0761
(1 ,0 ,.,.) =
0.0280 -0.2178 0.0914
0.3227 -0.0121 -0.0016
-0.0654 -0.0584 -0.1655
[torch.FloatTensor of size 2x1x3x3]
), ('nets.0.bias',
-0.0507
-0.2836
[torch.FloatTensor of size 2]
), ('nets.2.weight',
(0 ,0 ,.,.) =
-0.2233 0.0279 -0.0511
-0.0242 -0.1240 -0.0511
0.2266 0.1385 -0.1070
(0 ,1 ,.,.) =
-0.0943 -0.1403 0.0979
-0.2163 0.1906 -0.2269
-0.1984 0.0843 -0.0719
[torch.FloatTensor of size 1x2x3x3]
), ('nets.2.bias',
-0.1420
[torch.FloatTensor of size 1]
), ('nets.4.weight',
(0 ,0 ,.,.) =
0.1981 -0.0250 0.2429
0.3012 0.2428 -0.0114
0.2878 -0.2134 0.1173
[torch.FloatTensor of size 1x1x3x3]
), ('nets.4.bias',
1.00000e-02 *
-5.8426
[torch.FloatTensor of size 1]
)])
0 nets.0.weight
1 nets.0.bias
2 nets.2.weight
3 nets.2.bias
4 nets.4.weight
5 nets.4.bias
说明.state_dict()
只是把所有模型的参数都以OrderedDict</