导入预训练模型在通常情况下都能加快模型收敛,提升模型性能。但根据实际任务需求,自己搭建的模型往往和通用的Backbone并不能做到网络层的完全一致,无非就是少一些层和多一些层两种情况。
1. 自己模型层数较少
net = ... # net为自己的模型
save_model = torch.load('path_of_pretrained_model') # 获取预训练模型字典(键值对)
model_dict = net.state_dict() # 获取自己模型字典(键值对)
# 新定义字典,用来获取自己模型中对应层的预训练模型中的参数
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
model_dict.update(state_dict) # 更新自己模型字典中键值对
net.load_state_dict(model_dict) # 加载参数
其中update:对于state_dict和model_dict都有的键值对,前者对应的值会替换后者,若前者有后者没有的键值对,则会添加这些键值对到后者字典中。
2. 自己模型层数较多
model_dict = model.state_dict() # 自己模型字典
model_dict.update(pretrained_model) # 直接将预训练模型参数更新进来
model.load_state_dict(model_dict) # 加载
结果与预训练模型对应的层权重被加载了,其它层则为默认初始化。
由于模型参数都为字典形式存在,可以用字典的增删方式进行更灵活的操作,这里就不叙述了。