RuntimeError: Error(s) in loading state_dict for XXX

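This post covers a common PyTorch problem: missing and unexpected keys when loading a model that was trained with DataParallel. Wrapping the model in DataParallel before loading avoids the error.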

Save the trained model with torch:

torch.save(model.state_dict(), 'file.pkl')

Reload it for testing:

model.load_state_dict(torch.load('file.pkl'))
model = model.cuda()

This raises an error:

RuntimeError: Error(s) in loading state_dict for XXX:
	Missing key(s) in state_dict: "conv1.0.conv.weight", "conv1.1.weight", "conv1.1.bias", "conv1.1.running_mean", "conv1.1.running_var", "Block1.0.conv.weight", "Block1.1.weight", "Block1.1.bias", "Block1.1.running_mean", "Block1.1.running_var", "Block2.0.conv.weight", "Block2.1.weight", "Block2.1.bias", "Block2.1.running_mean", "Block2.1.running_var", "Block3.0.conv.weight", "Block3.1.weight", "Block3.1.bias", "Block3.1.running_mean", "Block3.1.running_var", "lastconv1.0.conv.weight", "lastconv1.1.weight", "lastconv1.1.bias", "lastconv1.1.running_mean", "lastconv1.1.running_var", "lastconv1.3.conv.weight", "sa1.conv1.weight", "sa2.conv1.weight", "sa3.conv1.weight". 
	Unexpected key(s) in state_dict: "module.conv1.0.conv.weight", "module.conv1.1.weight", "module.conv1.1.bias", "module.conv1.1.running_mean", "module.conv1.1.running_var", "module.conv1.1.num_batches_tracked", "module.Block1.0.conv.weight", "module.Block1.1.weight", "module.Block1.1.bias", "module.Block1.1.running_mean", "module.Block1.1.running_var", "module.Block1.1.num_batches_tracked", "module.Block2.0.conv.weight", "module.Block2.1.weight", "module.Block2.1.bias", "module.Block2.1.running_mean", "module.Block2.1.running_var", "module.Block2.1.num_batches_tracked", "module.Block3.0.conv.weight", "module.Block3.1.weight", "module.Block3.1.bias", "module.Block3.1.running_mean", "module.Block3.1.running_var", "module.Block3.1.num_batches_tracked", "module.lastconv1.0.conv.weight", "module.lastconv1.1.weight", "module.lastconv1.1.bias", "module.lastconv1.1.running_mean", "module.lastconv1.1.running_var", "module.lastconv1.1.num_batches_tracked", "module.lastconv1.3.conv.weight", "module.sa1.conv1.weight", "module.sa2.conv1.weight", "module.sa3.conv1.weight". 
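Reading the two lists side by side, every missing key shows up again in the unexpected list with "module." in front of it. A minimal sketch of checking that programmatically follows; the helper name diff_keys is hypothetical, and the key names are a small subset copied from the error above.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, as a strict load would report them."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# A few keys taken from the error message above.
model_keys = ["conv1.0.conv.weight", "conv1.1.weight", "conv1.1.bias"]
checkpoint_keys = ["module." + k for k in model_keys]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)

# True when each missing key is present in the checkpoint under a "module." prefix.
prefix_explains_missing = all("module." + k in unexpected for k in missing)
print(prefix_explains_missing)
```

In a real session the two inputs would be model.state_dict().keys() and torch.load('file.pkl').keys().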

Many posts blame a version mismatch and suggest changing load_state_dict(state_dict) to model.load_state_dict(state_dict, False), that is:

model.load_state_dict(torch.load('file.pkl'), False)

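This does make the error go away, but only by crudely ignoring every key that does not match. Here every key is mismatched, so none of the saved weights are actually loaded, the model is unusable, and its predictions are wildly off.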
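The effect of a non-strict load can be inspected directly, because load_state_dict returns an object with missing_keys and unexpected_keys fields. The sketch below uses a small nn.Linear as a hypothetical stand-in for the real network and simulates a checkpoint whose keys all carry the "module." prefix.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real network.
model = nn.Linear(4, 2)
before = {k: v.clone() for k, v in model.state_dict().items()}

# Simulated checkpoint saved from a DataParallel-wrapped model: same tensors'
# shapes, but every key is prefixed with "module.".
checkpoint = {"module." + k: torch.zeros_like(v) for k, v in model.state_dict().items()}

result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # keys the model expected but did not receive
print(result.unexpected_keys)  # keys in the checkpoint that matched nothing

# No parameter was overwritten: the model still holds its initial weights.
unchanged = all(torch.equal(before[k], v) for k, v in model.state_dict().items())
print(unchanged)
```

If both lists are non-empty after a non-strict load, it is worth checking whether any weights were loaded at all before trusting the model's output.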

After a lot of digging I found the cause: training ran on multiple GPUs in parallel, with the model wrapped in torch.nn.DataParallel:

model = torch.nn.DataParallel(model, device_ids=[0,1,2]).cuda()
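The wrapper keeps the original network under an attribute named module, so every key in the wrapped model's state_dict gains a "module." prefix. A minimal sketch that shows this, using a small hypothetical network in place of the real one (no GPU is needed just to construct the wrapper and list its keys):

```python
import torch.nn as nn

# Hypothetical stand-in for the real network.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
wrapped = nn.DataParallel(net)

plain_keys = list(net.state_dict().keys())
wrapped_keys = list(wrapped.state_dict().keys())

print(plain_keys[:2])    # keys of the bare model
print(wrapped_keys[:2])  # the same keys with "module." in front
```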

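The model was not wrapped in torch.nn.DataParallel when it was reloaded, so its keys lacked the "module." prefix that the saved keys carry. The fix is very simple: add one line, model = nn.DataParallel(model), before loading, that is:

model = nn.DataParallel(model)
model.load_state_dict(torch.load('file.pkl'))
model = model.cuda()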

Problem solved.
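An alternative that this post did not use is to save the state_dict of the inner module at training time, so that the checkpoint has no "module." prefix and a plain model can load it. The sketch below assumes a small stand-in network (build_model is a hypothetical name) and uses an in-memory buffer in place of 'file.pkl'.

```python
import io

import torch
import torch.nn as nn


def build_model():
    # Hypothetical stand-in for the real architecture.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


# Training side: the model is wrapped, but the inner module is what gets saved.
trained = nn.DataParallel(build_model())
buffer = io.BytesIO()  # stands in for 'file.pkl'
torch.save(trained.module.state_dict(), buffer)

# Test side: a plain, unwrapped model loads the checkpoint with strict checking on.
buffer.seek(0)
restored = build_model()
restored.load_state_dict(torch.load(buffer, map_location="cpu"))

# Verify that the restored weights equal the trained ones.
same = all(
    torch.equal(a.cpu(), b)
    for a, b in zip(trained.module.state_dict().values(), restored.state_dict().values())
)
print(same)
```

This only helps for checkpoints saved from now on; a file that was already saved with prefixed keys still needs the wrapper at load time or a key rename.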


When MobileNetV2 raises a RuntimeError while loading a state_dict, the fix depends on the type of error:

- **Unexpected key(s) in state_dict**: If loading fails with `RuntimeError: Error(s) in loading state_dict for Net: Unexpected key(s) in state_dict: XXX`, you can set `strict=False` when loading:

```python
import torch
from torchvision.models import mobilenet_v2

model = mobilenet_v2()
state_dict = torch.load('models/params.pt')
try:
    model.load_state_dict(state_dict)
except RuntimeError as e:
    # PyTorch capitalizes the message ("Unexpected key(s) ..."),
    # so compare case-insensitively.
    if 'unexpected key(s) in state_dict' not in str(e).lower():
        raise  # a different loading error: do not hide it
    model.load_state_dict(state_dict, strict=False)
```

This ignores the extra keys in the parameter dictionary, so they no longer cause a loading error[^1]. It only helps when the extra keys are genuinely unneeded; any parameter whose key does not match is left at its initial value.

- **Missing key(s) in state_dict**: If the model was trained in parallel on multiple GPUs, every key in the saved parameters starts with an extra `"module."` string. Strip that string before loading:

```python
import torch
from torchvision.models import mobilenet_v2

model = mobilenet_v2()
checkpoint = torch.load('models/params.pt')
new_state_dict = {}
for k, v in checkpoint.items():
    if k.startswith('module.'):
        k = k[len('module.'):]  # strip the "module." prefix
    new_state_dict[k] = v
model.load_state_dict(new_state_dict)
```

This resolves the key mismatch caused by parallel training[^3].

- **size mismatch**: If loading fails with `RuntimeError: Error(s) in loading state_dict for Network: size mismatch`, you can adapt the following approach, which discards the parameters that do not match:

```python
import os
import torch
from torchvision.models import mobilenet_v2
from termcolor import colored

# strict defaults to False here: the discarded keys are missing from the
# dict passed to load_state_dict, so a strict load would raise.
def load_network(net, model_dir, strict=False, map_location=None):
    if not os.path.exists(model_dir):
        print(colored('WARNING: NO MODEL LOADED !!!', 'red'))
        return 0
    print('load model: {}'.format(model_dir))
    if map_location is None:
        # Remap tensors saved on any of four GPUs to the CPU.
        pretrained_model = torch.load(
            model_dir,
            map_location={'cuda:0': 'cpu', 'cuda:1': 'cpu',
                          'cuda:2': 'cpu', 'cuda:3': 'cpu'})
    else:
        pretrained_model = torch.load(model_dir, map_location=map_location)
    if 'epoch' in pretrained_model.keys():
        epoch = pretrained_model['epoch'] + 1
    else:
        epoch = 0
    pretrained_model = pretrained_model['net']
    net_weight = net.state_dict()
    for key in net_weight.keys():
        net_weight[key] = pretrained_model[key]
    # Discard the mismatched parameters; adjust the names to your model.
    net_weight.pop("some_layer.weight")
    net_weight.pop("some_layer.bias")
    net.load_state_dict(net_weight, strict=strict)
    return epoch

model = mobilenet_v2()
load_network(model, 'models/params.pt')
```

By discarding the mismatched parameters, this lets the model load the remaining parameters that do match[^4].