当保存了神经网络模型后,进行训练加载模型的时候一直出错,代码如下:
model = model.load_state_dict(torch.load(path))
运行结果如下,出现:
Error(s) in loading state_dict for ResNet18
出现这样的问题,应该是版本的原因,这样的语句现在新的版本好像不支持了(如果有朋友知道具体的原因,欢迎留言指正哈……)
这里将上面的代码改成如下这样:
model_dict = torch.load(root)
model_dict_clone = model_dict.copy()
for key, value in model_dict_clone.items():
if key.endswith(('running_mean', 'running_var')):
del model_dict[key]
model.load_state_dict(model_dict, False)
还有一种方法是直接在原代码上加上False参数,如下:
model.load_state_dict(torch.load(root),False)
小白学习时间不长,欢迎朋友们评论交流……