在魔改模型后还想要加载原先能加载的部分训练权重
# 加载预训练权重
pretrained_dict = torch.load(pretrained_path)
if 'state_dict' in pretrained_dict:
pretrained_dict = pretrained_dict['state_dict']
# 过滤掉不匹配的键
model_dict = self.resnet.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
model_dict.update(pretrained_dict)
self.resnet.load_state_dict(model_dict)
`
1004

被折叠的 条评论
为什么被折叠?



