在服务器上运行网络时出现错误:
RuntimeError: Error(s) in loading state_dict for ReflectionNetwork:
原来:
checkpoint = torch.load("/Pre_train/checkpoint_81_test1.pth.tar",map_location='cuda:0')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
修改后:
checkpoint = torch.load("/Pre_train/checkpoint_81_test1.pth.tar",map_location='cuda:0')
model.load_state_dict(checkpoint['state_dict'],False)
model.eval()
完结撒花!

博客分享了如何修复服务器上ReflectionNetwork模型加载state_dict时的RuntimeError,通过修改load_state_dict函数参数,成功解决了问题并完成了模型评估。
2018

被折叠的 条评论
为什么被折叠?



