1.在载入模型参数前加上:
model = nn.DataParallel(model)
比如我的:
model_effi7 = torch.nn.DataParallel(model_effi7)
model_effi7.load_state_dict(torch.load(model_path_effi7))
若再出现:RuntimeError:Error(s) in loading state_dict for DataParallel:
则如此修改,从属性state_dict里面复制参数到这个模块和它的后代。如果strict为True, state_dict的keys必须完全与这个模块的方法返回的keys相匹配。如果为False,就不需要保证匹配。
model_effi7.load_state_dict(torch.load(model_path_effi7),False)

本文介绍如何在遇到RuntimeError:Error(s) in loading state_dict for DataParallel时,正确地将模型参数从单独文件迁移到DataParallel模块,并设置strict=False以适应不完全匹配的情况。
2018

被折叠的 条评论
为什么被折叠?



