最近笔者在使用MMSegmentation框架进行实验时,遇到了MMSegmentation中backbone的结构与MMPretrain提供的预训练权重结构不完全匹配的问题,导致预训练权重无法载入,现提出以下解决办法:
众所周知,预训练权重文件(.pth)本质上保存的是一个字典,权重参数以键值对的形式存放在其中(通常位于'state_dict'键下)。因此解决问题的思路很简单:先使用PyTorch的torch.load()加载权重文件,得到一个字典,再使用Python字典的.keys()方法取出所有键名,最后遍历这些键名,去掉多余的"backbone."前缀,并将改名后的权重另存为新的pth文件即可,代码如下:
import torch
def replace_backbone_prefix(weight_key):
    """Return ``weight_key`` with a single leading ``"backbone."`` removed.

    Args:
        weight_key: A state-dict key such as ``"backbone.layer1.conv.weight"``.

    Returns:
        The key without its leading ``"backbone."`` prefix. Keys that do not
        start with that prefix are returned unchanged.
    """
    prefix = "backbone."
    if weight_key.startswith(prefix):
        # Slice instead of str.replace(): replace() would also delete any
        # later "backbone." that appears inside the key, corrupting names
        # such as "backbone.neck.backbone.conv.weight".
        return weight_key[len(prefix):]
    return weight_key
def save_modified_state_dict(filepath, new_state_dict):
    """Save a copy of a checkpoint with its ``'state_dict'`` entry replaced.

    The checkpoint at ``filepath`` is loaded onto the CPU, its
    ``'state_dict'`` entry is set to ``new_state_dict`` (all other entries
    are kept), and the result is written next to the original file as
    ``<name>_modified<ext>``. The source file itself is never overwritten.

    Args:
        filepath: Path of the checkpoint file to read.
        new_state_dict: Mapping to store under the ``'state_dict'`` key.

    Returns:
        The path of the newly written file, or ``None`` when the loaded
        object is not a dict or when loading/saving fails (a message is
        printed in those cases).
    """
    import os  # local import: keeps this function self-contained

    try:
        checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
        if not isinstance(checkpoint, dict):
            print("无法在文件中找到state_dict键,请确保你的.pth文件是PyTorch模型的权重文件。")
            return None

        checkpoint['state_dict'] = new_state_dict

        # Insert the suffix before the real file extension only.
        # str.replace('.pth', ...) would also rewrite ".pth" inside directory
        # names, and would return the input path unchanged (overwriting the
        # original checkpoint) when the file has no ".pth" in its name.
        root, ext = os.path.splitext(filepath)
        modified_filepath = root + '_modified' + ext

        torch.save(checkpoint, modified_filepath)
        print("修改后的state_dict已成功保存到文件:", modified_filepath)
        return modified_filepath
    except FileNotFoundError:
        print("文件路径无效,请检查输入的路径是否正确。")
        return None
    except Exception as e:
        # Best-effort tool: report the problem instead of crashing the script.
        print("保存修改后的state_dict时出现错误:", e)
        return None
def print_keys_from_pth(filepath):
    """Print a checkpoint's keys before/after prefix removal and save the result.

    Loads the checkpoint at ``filepath`` onto the CPU, prints the keys of its
    ``'state_dict'`` entry, prints the keys produced by
    ``replace_backbone_prefix``, then builds the renamed state dict and hands
    it to ``save_modified_state_dict``.

    Args:
        filepath: Path of the checkpoint file to inspect and convert.

    Returns:
        None. Problems (missing file, no ``'state_dict'`` entry, load errors)
        are reported with a printed message rather than raised.
    """
    try:
        checkpoint = torch.load(filepath, map_location=torch.device('cpu'))
        if not (isinstance(checkpoint, dict) and 'state_dict' in checkpoint):
            print("无法在文件中找到state_dict键,请确保你的.pth文件是PyTorch模型的权重文件。")
            return

        state_dict = checkpoint['state_dict']
        old_keys = list(state_dict.keys())
        new_keys = [replace_backbone_prefix(key) for key in old_keys]
        print("替换前的键名:")
        print(old_keys)
        print("\n替换后的键名:")
        print(new_keys)

        # Build the renamed state dict. Two different source keys can map to
        # the same new name (e.g. "backbone.x" and "x"); the later one wins,
        # as in a plain dict comprehension, but the collision is reported so
        # that a weight is never dropped silently.
        new_state_dict = {}
        collided_keys = []
        for old_key, new_key in zip(old_keys, new_keys):
            if new_key in new_state_dict:
                collided_keys.append(new_key)
            new_state_dict[new_key] = state_dict[old_key]
        if collided_keys:
            print("警告:以下键名在去除前缀后发生重复,靠后的权重覆盖了靠前的权重:", collided_keys)

        save_modified_state_dict(filepath, new_state_dict)
    except FileNotFoundError:
        print("文件路径无效,请检查输入的路径是否正确。")
    except Exception as e:
        # Best-effort tool: report the problem instead of crashing the script.
        print("加载权重文件时出现错误:", e)
if __name__ == "__main__":
    # Script entry point: ask for the checkpoint path on stdin and convert it.
    print_keys_from_pth(input("请输入.pth权重文件路径:"))