1. 预训练参数保存
def savecheckpoint(model,dir):
para=model.state_dict()
torch.save(para, dir)
2. 预训练参数加载
def LoadParameter(_structure, _parameterDir):
"""
:param _structure: model
:param _parameterDir: 参数位置
:return:
"""
checkpoint = torch.load(_parameterDir)
pretrained_state_dict = checkpoint['state_dict']
model_state_dict = _structure.state_dict()
for key in pretrained_state_dict:
if key in model_state_dict:
model_state_dict[key] = pretrained_state_dict[key]
"""
another method:
pretrained_dict = {k: v for k, v in pretrained_state_dict.items() if k in model_state_dict}
model_state_dict.update(pretrained_dict)
"""
_structure.load_state_dict(model_state_dict)
model = torch.nn.DataParallel(_structure).cuda()
return model

本文详细介绍了如何使用PyTorch保存和加载预训练模型的参数,包括定义保存函数和加载函数,确保模型能够在训练后被正确保存,并在需要时重新加载用于后续任务。
7109

被折叠的 条评论
为什么被折叠?



