解决MMsegmentation与MMpretrain中预训练权重中键名不匹配问题

最近在笔者使用MMsegmentation框架进行实验时,有遇到MMsegmentation中backbone的结构与MMpretrain中提供的预训练权重结构不完全匹配的问题,导致与训练权重载入不进去,提出以下解决办法:

众所周知预训练权重其实就是一些字典,权重参数以键值对的形式存储为pth文件,那么解决问题的思路就很简单,使用pytorch的torch.load()加载权重文件,加载完后是一个字典形式,再使用python中的.items()方法获取字典中所有键值对的迭代器,最后用for循环遍历所有的键并改名即可,代码如下:

import torch


def replace_backbone_prefix(weight_key):
    if weight_key.startswith("backbone."):
        return weight_key.replace("backbone.", "")
    else:
        return weight_key


def save_modified_state_dict(filepath, new_state_dict):
    try:
        checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
        if isinstance(checkpoint, dict):
            checkpoint['state_dict'] = new_state_dict
            modified_filepath = filepath.replace('.pth', '_modified.pth')
            torch.save(checkpoint, modified_filepath)
            print("修改后的state_dict已成功保存到文件:", modified_filepath)
            return modified_filepath
        else:
            print("无法在文件中找到state_dict键,请确保你的.pth文件是PyTorch模型的权重文件。")
            return None
    except FileNotFoundError:
        print("文件路径无效,请检查输入的路径是否正确。")
        return None
    except Exception as e:
        print("保存修改后的state_dict时出现错误:", e)
        return None


def print_keys_from_pth(filepath):
    try:
        checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
            keys = state_dict.keys()
            new_keys = [replace_backbone_prefix(key) for key in keys]
            print("替换前的键名:")
            print(list(keys))
            print("\n替换后的键名:")
            print(list(new_keys))

            # 保存修改后的state_dict
            new_state_dict = {new_key: state_dict[key] for new_key, key in zip(new_keys, keys)}
            save_modified_state_dict(filepath, new_state_dict)
        else:
            print("无法在文件中找到state_dict键,请确保你的.pth文件是PyTorch模型的权重文件。")
    except FileNotFoundError:
        print("文件路径无效,请检查输入的路径是否正确。")
    except Exception as e:
        print("加载权重文件时出现错误:", e)


if __name__ == "__main__":
    filepath = input("请输入.pth权重文件路径:")
    print_keys_from_pth(filepath)

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值