PyTorch "Error(s) in loading state_dict": renaming state_dict keys with the re module

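This article explains how PyTorch's model.load_state_dict works and compares it with Keras's load_weights. It shows how to fix loading errors caused by naming mismatches between a model and its weight file by rewriting the dictionary keys with regular expressions, so that complex architectures such as DenseNet load correctly.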

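When loading a weight file with model.load_state_dict, matching works like Keras's load_weights(by_name=True): each key in the file's state_dict is matched strictly against the name of the corresponding module parameter or buffer in the model. If the names do not match, loading fails with Error(s) in loading state_dict (strict=True is the default, and the message lists the missing and unexpected keys).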
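A minimal sketch of this strict matching, assuming a hypothetical checkpoint whose head was saved as fc = Sequential(Dropout, Linear) (keys fc.1.*) while the rebuilt model uses a bare fc = Linear (expecting fc.*). OldNet and NewNet are illustrative names, not taken from the original code.

```python
import torch.nn as nn


class OldNet(nn.Module):
    """Hypothetical network that produced the checkpoint: head is fc.0 (Dropout) + fc.1 (Linear)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(nn.Dropout(0.5), nn.Linear(4, 2))


class NewNet(nn.Module):
    """Hypothetical rebuilt network: head is a bare Linear named fc."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


old_state = OldNet().state_dict()
print(list(old_state.keys()))  # ['fc.1.weight', 'fc.1.bias']

try:
    NewNet().load_state_dict(old_state)  # strict=True by default
    error_message = ""
except RuntimeError as err:
    # The message names both the missing keys (fc.weight, fc.bias)
    # and the unexpected keys (fc.1.weight, fc.1.bias).
    error_message = str(err)

print(error_message)
```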

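Note

PyTorch differs from Keras's load_weights(by_name=False). Because TensorFlow builds a static graph, Keras can match weights by the network's topology, that is, by the layer hierarchy shared by the model and the weight file.
PyTorch uses a dynamic graph, and the network structure may even change during training. Therefore, with load_state_dict(..., strict=False), mismatched keys no longer raise an error, but the affected modules are simply not loaded from the corresponding entries and keep their initial values. (strict=False only relaxes name matching; a key whose tensor shape differs from the model's still raises a size-mismatch error.)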
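The silent failure with strict=False can be seen in a small sketch, again assuming the hypothetical fc.1.* vs fc.* mismatch. load_state_dict returns a named tuple whose missing_keys and unexpected_keys fields report what was skipped, which is worth checking instead of ignoring.

```python
import torch
import torch.nn as nn

# Hypothetical checkpoint: head stored as fc.1.weight / fc.1.bias.
checkpoint = nn.Sequential(nn.Dropout(0.5), nn.Linear(4, 2)).state_dict()
checkpoint = {"fc." + k: v for k, v in checkpoint.items()}


class NewNet(nn.Module):
    """Hypothetical rebuilt model expecting fc.weight / fc.bias."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)


model = NewNet()
before = model.fc.weight.detach().clone()

# No exception: the mismatches are only reported in the return value.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['fc.weight', 'fc.bias']
print(result.unexpected_keys)  # ['fc.1.weight', 'fc.1.bias']

# Nothing was copied, so fc keeps its random initialisation.
print(torch.equal(model.fc.weight, before))  # True
```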

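To load the weights correctly, we therefore need to rewrite the keys of the weight file's dictionary so that they match the model we rebuilt. Here we use the re module to match the key strings against a regular expression.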
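The general idea can be sketched as a small helper that rewrites every key with re.sub. rename_state_dict_keys is a hypothetical name, and the sample keys below use plain integers in place of tensors; the helper raises on collisions so that two old keys cannot silently overwrite each other.

```python
import re
from collections import OrderedDict


def rename_state_dict_keys(state_dict, pattern, replacement):
    """Return a new OrderedDict with every key rewritten by re.sub.

    Keys that do not match `pattern` are copied unchanged, values are not
    touched, and the input dict is left as it was.
    """
    regex = re.compile(pattern)
    renamed = OrderedDict()
    for key, value in state_dict.items():
        new_key = regex.sub(replacement, key)
        if new_key in renamed:
            raise ValueError(f"two checkpoint keys map to {new_key!r}")
        renamed[new_key] = value
    return renamed


# Hypothetical checkpoint keys; integers stand in for tensors.
checkpoint = OrderedDict([("conv.weight", 1), ("fc.1.weight", 2), ("fc.1.bias", 3)])
fixed = rename_state_dict_keys(checkpoint, r"^fc\.1\.", "fc.")
print(list(fixed))  # ['conv.weight', 'fc.weight', 'fc.bias']
```

With real tensors, the returned dict would then be passed to model.load_state_dict as usual.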

DenseNet

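Take the DenseNet fix in the official torchvision code as an example. As its comment explains, the layers in the original weight files are named 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2', whereas the rebuilt network uses names such as norm1 directly, so the '.' has to be removed.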

import re

from torch.hub import load_state_dict_from_url


def _load_state_dict(model, model_url, progress):
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = load_state_dict_from_url(model_url, progress=progress)
    # Iterate over a snapshot of the keys, because the loop adds and deletes entries.
    for key in list(state_dict.keys()):
        res = pattern.match(key)
        if res:
            # e.g. '...denselayer1.norm' + '1.weight' -> '...denselayer1.norm1.weight'
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict)
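The renaming loop can be checked offline, without downloading a checkpoint. This sketch runs the same loop on a few hand-written sample keys in the old style (illustrative only, not a full DenseNet state_dict), with integers standing in for tensors.

```python
import re

# Same pattern as in _load_state_dict above.
pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

# Hypothetical sample keys; only the denselayer entries should be renamed.
state_dict = {
    "features.conv0.weight": 0,
    "features.denseblock1.denselayer1.norm.1.weight": 1,
    "features.denseblock1.denselayer1.conv.2.weight": 2,
    "features.denseblock2.denselayer12.norm.2.running_var": 3,
    "features.norm5.bias": 4,
}

# Iterating over state_dict.keys() directly while adding/deleting entries
# would raise RuntimeError, hence the list() snapshot.
for key in list(state_dict.keys()):
    res = pattern.match(key)
    if res:
        new_key = res.group(1) + res.group(2)
        state_dict[new_key] = state_dict[key]
        del state_dict[key]

print(sorted(state_dict))
```

Keys outside a dense layer, such as features.conv0.weight and features.norm5.bias, do not match and are left alone.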

The key step here is the regular expression used for matching:

pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

Key in the weight file:

'features.denseblock1.denselayer1.norm.1.weight'

Key in the model:

'features.denseblock1.denselayer1.norm1.weight'
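Going from the first key to the second can also be written as a single substitution. This is an alternative sketch, not the torchvision code: pattern.sub with the backreferences \1\2 joins the two capture groups and drops the dot between them.

```python
import re

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

old_key = 'features.denseblock1.denselayer1.norm.1.weight'
# \1 and \2 re-insert the two capture groups; the dot between them is not kept.
new_key = pattern.sub(r'\1\2', old_key)
print(new_key)  # features.denseblock1.denselayer1.norm1.weight

# A key that does not match is returned unchanged.
print(pattern.sub(r'\1\2', 'features.norm5.weight'))  # features.norm5.weight
```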

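It can look confusing at first, and it is admittedly a bit convoluted, so let's read it from left to right:
r'...' a raw string, so backslashes such as \d and \. reach the regex engine unchanged
^ matches the start of the string
( ... ) a capturing group, whose match can later be read with res.group(n)
.* matches any character (.), repeated zero or more times (*)
\d+ matches one or more digits
\. an escaped dot, matching a literal .
(?:norm|relu|conv) a non-capturing group matching one of the three strings
(?:[12]) a non-capturing group matching the digit 1 or 2
(?:weight|bias|running_mean|running_var) matches one of the four parameter/buffer names
$ matches the end of the string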
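The walkthrough above can also be written into the pattern itself. This sketch is an equivalent rewrite using re.VERBOSE, in which whitespace and # comments inside the pattern are ignored; it is not how torchvision writes it.

```python
import re

# Same pattern as the compact one-liner, annotated piece by piece.
pattern = re.compile(r"""
    ^                                   # start of the key
    (                                   # group 1: the module prefix
        .*                              #   any characters, e.g. 'features.denseblock1.'
        denselayer\d+                   #   'denselayer' followed by one or more digits
        \.                              #   a literal dot
        (?:norm|relu|conv)              #   non-capturing: one of the three names
    )
    \.                                  # the literal dot that gets removed
    (                                   # group 2: the parameter suffix
        (?:[12])                        #   non-capturing: the digit 1 or 2
        \.                              #   a literal dot
        (?:weight|bias|running_mean|running_var)
    )
    $                                   # end of the key
""", re.VERBOSE)

m = pattern.match('features.denseblock1.denselayer1.norm.1.weight')
print(m.groups())  # ('features.denseblock1.denselayer1.norm', '1.weight')
```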

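res.group(n) returns the text matched by the n-th capturing group after r', numbered by the order of the opening parentheses. (?:...) groups are not counted, and res.group(0) is the whole match: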

res.group(0)
'features.denseblock1.denselayer1.norm.1.weight'
res.group(1)
'features.denseblock1.denselayer1.norm'
res.group(2)
'1.weight'
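The group numbering can be confirmed directly. In this sketch, the compiled pattern's groups attribute counts only the capturing groups, and match returns None for keys the pattern does not cover.

```python
import re

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

# Only plain (...) groups are numbered; the (?:...) groups are not.
print(pattern.groups)  # 2

res = pattern.match('features.denseblock1.denselayer1.norm.1.weight')
print(res.group(0))  # features.denseblock1.denselayer1.norm.1.weight
print(res.group(1))  # features.denseblock1.denselayer1.norm
print(res.group(2))  # 1.weight

# Keys outside a dense layer do not match at all.
print(pattern.match('features.norm5.weight'))  # None
```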

re module references

https://www.runoob.com/python/python-reg-expressions.html
https://www.cnblogs.com/shenjianping/p/11647473.html

