The code is as follows:
def load_ae(self, network):
    pdb.set_trace()
    save_filename = 'latest_net_AE.pth'
    save_path = os.path.join(self.save_dir, save_filename)
    network.load_state_dict(torch.load(save_path))
The error:
KeyError: 'unexpected key "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean" in state_dict'
According to a solution found online (https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3), the explanation for this error is that the model was saved with nn.DataParallel, which prepends module. to every key in the stored state_dict, but it is being loaded without nn.DataParallel:
"You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without DataParallel. You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back."
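As a quick sanity check (a minimal sketch, not from the original post), wrapping any module in nn.DataParallel really does prepend module. to every key of its state_dict:

```python
import torch.nn as nn

# A tiny stand-in model, just to inspect the parameter key names.
net = nn.Sequential(nn.Linear(4, 4))
print(sorted(net.state_dict().keys()))      # ['0.bias', '0.weight']

# Wrapping in DataParallel (constructing it works even without a GPU)
# prefixes every key with 'module.'.
wrapped = nn.DataParallel(net)
print(sorted(wrapped.state_dict().keys()))  # ['module.0.bias', 'module.0.weight']
```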
Following the answer at that link, the code was changed to:
def load_ae(self, network):
    pdb.set_trace()
    save_filename = 'latest_net_AE.pth'
    save_path = os.path.join(self.save_dir, save_filename)
    # network.load_state_dict(torch.load(save_path))
    # original saved file with DataParallel
    state_dict = torch.load(save_path)
    # create new OrderedDict that does not contain `module.`
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace(".module", "")  # remove `module.`
        new_state_dict[name] = v
    # load params
    network.load_state_dict(new_state_dict)
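As an aside, the replacement string in that loop is itself subtly wrong: the DataParallel prefix is module. (with a trailing dot), not .module, so k.replace(".module", "") removes nothing from a key like module.conv.weight. A prefix-aware strip (a sketch, using a hypothetical helper name) avoids both that mistake and accidentally mangling keys that merely contain the substring somewhere in the middle:

```python
from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Remove a leading 'module.' (note the trailing dot) from every key."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Only strip when the key actually starts with the prefix.
        name = k[len("module."):] if k.startswith("module.") else k
        new_state_dict[name] = v
    return new_state_dict
```

Though, as the rest of the post shows, the prefix turns out not to be the issue here at all.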
However, the error still occurs:
File "/share2/home/ruixu/DeepLearningLCT/DeblurGAN-master/AEOT_Unet/models/base_model.py", line 70, in load_ae
name = k.replace(".module","") # remove `module.`
File "/share2/home/ruixu/anaconda3/envs/python36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
.format(name))
KeyError: 'unexpected key "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean" in state_dict'
Unbelievably, the error is exactly the same! And stranger still, the reported location is the line name = k.replace(".module","")  # remove `module.`.
Let's work through this step by step.
First, print network.state_dict():
... ....
('model.model.3.weight',
( 0 , 0 ,.,.) =
1.00000e-02 *
-1.2084 -2.2976 -2.8284 -1.5688
3.0504 -0.1078 -1.6116 -6.6927
-1.2499 1.8116 0.3322 1.9724
-2.9073 4.4396 2.2341 -0.2121
(127, 2 ,.,.) =
1.00000e-02 *
3.9328 -0.4361 1.6365 -3.3506
-0.8479 -1.9702 -2.2223 1.9633
2.9766 -0.4350 -0.2136 -2.5228
2.0793 -3.1964 -0.3516 3.7284
[torch.cuda.FloatTensor of size 128x3x4x4 (GPU 3)]
), ('model.model.3.bias',
0
0
0
[torch.cuda.FloatTensor of size 3 (GPU 3)]
)])
Next, load the trained weights and print them:
... ....
('model.model.3.weight',
( 0 , 0 ,.,.) =
1.00000e-02 *
-4.7173 -1.7712 0.5993 1.4681
2.1384 1.6073 2.7149 -4.4520
-0.9415 -1.0043 -2.8195 1.4755
-1.7620 2.9188 -0.9800 -1.5442...
(127, 2 ,.,.) =
1.00000e-02 *
0.5911 1.2228 -4.6552 -0.9199
0.1930 0.9382 -0.1070 -0.5980
0.5762 2.1807 3.0031 -1.0809
0.2351 -0.3915 0.1441 -1.6821
[torch.FloatTensor of size 128x3x4x4]
), ('model.model.3.bias',
1.00000e-03 *
-1.8013
-1.7175
-5.1277
[torch.FloatTensor of size 3]
)])
Judging from the last two entries alone, they look identical! So where is the problem?
Given the error message KeyError: 'unexpected key "model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean" in state_dict', let's compare all the keys of the two state_dicts.
First, print all the keys of the trained weights (state_dict) that we are trying to load:
(Pdb) print('\n'.join(map(str,sorted(state_dict.keys()))))
- model.model.0.bias
- model.model.0.weight
- model.model.1.model.1.bias
- model.model.1.model.1.weight
- model.model.1.model.2.running_mean
- model.model.1.model.2.running_var
- model.model.1.model.3.model.1.bias
- model.model.1.model.3.model.1.weight
- model.model.1.model.3.model.2.running_mean
- model.model.1.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.4.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.4.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.5.bias
- model.model.1.model.3.model.5.weight
- model.model.1.model.3.model.6.running_mean
- model.model.1.model.3.model.6.running_var
- model.model.1.model.5.bias
- model.model.1.model.5.weight
- model.model.1.model.6.running_mean
- model.model.1.model.6.running_var
- model.model.3.bias
- model.model.3.weight
Then print all the keys of network's freshly initialized weights:
(Pdb) print('\n'.join(map(str,sorted(network.state_dict().keys()))))
- model.model.0.bias
- model.model.0.weight
- model.model.1.model.1.bias
- model.model.1.model.1.weight
- model.model.1.model.2.running_mean
- model.model.1.model.2.running_var
- model.model.1.model.3.model.1.bias
- model.model.1.model.3.model.1.weight
- model.model.1.model.3.model.2.running_mean
- model.model.1.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.3.model.2.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.2.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.1.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.bias
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.weight
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.4.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.3.model.4.running_var
- model.model.1.model.3.model.3.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.3.model.5.bias
- model.model.1.model.3.model.3.model.5.weight
- model.model.1.model.3.model.3.model.6.running_mean
- model.model.1.model.3.model.3.model.6.running_var
- model.model.1.model.3.model.5.bias
- model.model.1.model.3.model.5.weight
- model.model.1.model.3.model.6.running_mean
- model.model.1.model.3.model.6.running_var
- model.model.1.model.5.bias
- model.model.1.model.5.weight
- model.model.1.model.6.running_mean
- model.model.1.model.6.running_var
- model.model.3.bias
- model.model.3.weight
58 keys in one and 50 in the other — they are not the same after all!
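Rather than eyeballing two 50+ line listings, the mismatch can be found mechanically with a set difference (a minimal sketch with shortened stand-in keys; in the pdb session this would be called on state_dict.keys() and network.state_dict().keys()):

```python
def diff_keys(saved_keys, net_keys):
    """Return (unexpected, missing): keys only in the checkpoint,
    and keys only in the network."""
    saved, net = set(saved_keys), set(net_keys)
    return sorted(saved - net), sorted(net - saved)

# Stand-in key lists just to show the shape of the result.
unexpected, missing = diff_keys(
    ["model.0.weight", "model.1.weight", "model.2.running_mean"],
    ["model.0.weight", "model.1.weight"],
)
print(unexpected)  # ['model.2.running_mean']
print(missing)     # []
```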
Further inspection of the code revealed the problem:
the checkpoint came from the UNet built for 256-sized inputs, while the network was the UNet for 128-sized inputs — one fewer level of downsampling, which is where the 8 extra keys (58 vs. 50) come from.
After fixing that mismatch, the problem was solved.
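One way to spot this kind of mismatch up front: in pix2pix-style nested UNets, the recursion shows up as repeated model.3. segments in the key names, so the nesting depth of the two state_dicts can be compared directly (a sketch; the two sample keys are the deepest entries from the listings above):

```python
def max_nesting(keys, marker="model.3."):
    """Deepest recursion level implied by the key names -- a rough proxy
    for the number of nested UNet blocks."""
    return max(k.count(marker) for k in keys)

saved_deepest = "model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.model.3.weight"
net_deepest = "model.model.1.model.3.model.3.model.3.model.3.model.3.model.3.weight"
print(max_nesting([saved_deepest]))  # 7 nested blocks in the checkpoint
print(max_nesting([net_deepest]))    # 6 nested blocks in the network
```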
Summary
Although it turned out to be a silly mix-up in the end, the process of uncovering it was still quite instructive.