1. 报错内容
/opt/tools/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1221
1222 if len(error_msgs) > 0:
-> 1223 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1224 self.__class__.__name__, "\n\t".join(error_msgs)))
1225 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for UNet:
Missing key(s) in state_dict:................(太多了)
Unexpected key(s) in state_dict: ................(也很多)

2. 解决方案
train的时候我加了数据并行
# train
model = nn.DataParallel(model)
但是test里没加并行,加上就不会报错了
# test
model = UNet(n_channels=3, n_classes=1)
model = nn.Dat

博客讲述了在PyTorch中遇到的模型加载权重时的RuntimeError,原因是模型结构与权重文件键名不匹配。解决办法是在测试阶段同样使用nn.DataParallel。还探讨了nn.DataParallel的作用,它用于模型并行计算,可能导致键名变化。另外,作者建议在保存DataParallel模型时使用model.module.state_dict(),以便在推理时不需再使用DataParallel。
最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



