【学习参考】Pytorch 加载、查看预训练模型参数、使用部分预训练模型参数初始化网络
- 取出自己网络的参数字典
- 加载预训练网络的参数字典
- 取出预训练网络的参数字典
- 自己网络和预训练网络结构一致的层,使用预训练网络对应层的参数初始化
# 取出自己网络的参数字典
model_dict = model.state_dict()
# 加载预训练网络的参数字典
pretrained_dict = torch.load("xxxxxx.pth")
# 取出预训练网络的参数字典
keys = []
for k, v in pretrained_dict.items():
keys.append(k)
i = 0
# 自己网络和预训练网络结构一致的层,使用预训练网络对应层的参数初始化
for k, v in model_dict.items():
if v.size() == pretrained_dict[keys[i]].size():
model_dict[k] = pretrained_dict[keys[i]]
#print(model_dict[k])
i = i + 1
model.load_state_dict(model_dict)
2【报错】IndexError:list index out of range
检查list len&#x