unexpected key "module.conv1_1.weight" in state_dict

本文解决在PyTorch中加载模型时遇到的KeyError问题,详细分析了错误产生的原因,包括训练与预测阶段GPU使用的不一致,并提供了两种有效的解决方案:一是通过DataParallel模块确保GPU一致性;二是修改state_dict键名,去除不必要的'module.'前缀。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

torch加载模型时出现如下错误

异常位置
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
network.load_state_dict(torch.load(save_path))
异常信息
File "/data/Muyi/Github/EnlightenGAN/models/base_model.py", line 54, in load_network
network.load_state_dict(torch.load(save_path))
File "/usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 522, in load_state_dict
.format(name))
KeyError: 'unexpected key "module.conv1_1.weight" in state_dict'

异常原因
最终原因在此:训练时使用GPU,使用了torch.nn.DataParallel(),而此时预测没有使用GPU,即没有使用此模块导致上述异常
if len(gpu_ids) > 0:
    netG.cuda(device=gpu_ids[0])
    netG = torch.nn.DataParallel(netG, gpu_ids)
解决方案
1):加上torch.nn.DataParallel()模块,类似我的问题只需要使用GPU即可正常运行
model = torch.nn.DataParallel(model)
cudnn.benchmark = True
2):将原来字典中module.删除掉
network.load_state_dict({k.replace('module.',''):v for k,v in torch.load(save_path).items()})
更改原来代码如下,即可在CPU/GPU下都正常运行
if len(self.gpu_ids):
    network.load_state_dict(torch.load(save_path))
else:
    network.load_state_dict({k.replace('module.',''):v for k,v in torch.load(save_path).items()})
RuntimeError: Error(s) in loading state_dict for UNet: Missing key(s) in state_dict: "inc.double_conv.0.weight", "inc.double_conv.2.weight", "down1.maxpool_conv.1.double_conv.0.weight", "down1.maxpool_conv.1.double_conv.2.weight", "down2.maxpool_conv.1.double_conv.0.weight", "down2.maxpool_conv.1.double_conv.2.weight", "down3.maxpool_conv.1.double_conv.0.weight", "down3.maxpool_conv.1.double_conv.2.weight", "deconv.weight", "deconv.bias", "up1.up.weight", "up1.up.bias", "up1.conv.double_conv.0.weight", "up1.conv.double_conv.2.weight", "up2.up.weight", "up2.up.bias", "up2.conv.double_conv.0.weight", "up2.conv.double_conv.2.weight", "up3.up.weight", "up3.up.bias", "up3.conv.double_conv.0.weight", "up3.conv.double_conv.2.weight", "outc.conv.weight", "outc.conv.bias", "CB512.conv_mask.weight", "CB512.conv_mask.bias", "CB512.channel_add_conv.0.weight", "CB512.channel_add_conv.0.bias", "CB512.channel_add_conv.2.weight", "CB512.channel_add_conv.2.bias", "CB256.conv_mask.weight", "CB256.conv_mask.bias", "CB256.channel_add_conv.0.weight", "CB256.channel_add_conv.0.bias", "CB256.channel_add_conv.2.weight", "CB256.channel_add_conv.2.bias", "CB128.conv_mask.weight", "CB128.conv_mask.bias", "CB128.channel_add_conv.0.weight", "CB128.channel_add_conv.0.bias", "CB128.channel_add_conv.2.weight", "CB128.channel_add_conv.2.bias", "CB64.conv_mask.weight", "CB64.conv_mask.bias", "CB64.channel_add_conv.0.weight", "CB64.channel_add_conv.0.bias", "CB64.channel_add_conv.2.weight", "CB64.channel_add_conv.2.bias". Unexpected key(s) in state_dict: "feature_ir.batch_norm1.weight", "feature_ir.batch_norm1.bias", "feature_ir.batch_norm1.running_mean", "feature_ir.batch_norm1.running_var", "feature_ir.batch_norm1.num_batches_tracked", "feature_ir.batch_norm2.weight", "feature_ir.batch_norm2.bias", "feature_ir.batch_norm2.running_mean", "feature_ir.batch_norm2.running_var", "feature_ir.batch_norm2.num_batches_tracked", "feature_ir.batch_norm3.weight", "feature_ir.batch_norm3.bias", "feature_ir.b
最新发布
03-31
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值