只需要将预训练网络和新网络的键值对拿出来作比较,将新网络中没有的键值对去掉,然后再更新网络权重即可,代码如下:
#加载model,model是自己定义好的模型
resnet50 = models.resnet50(pretrained=True)
model =Net(...)
#读取参数 预训练参数和当前网络参数
pretrained_dict =resnet50.state_dict()
model_dict = model.state_dict()
#将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)