pytorch中保存模型相关的函数有3个:
torch.save:利用python的pickle模块实现序列化并保存序列化后的objecttorch.load:利用pickle将保存的object反序列化torch.nn.Module.load_state_dict:通过反序列化得到的state_dict读取保存的训练参数
有两种方法保存模型:
1. torch.save(model, path) # 直接保存整个模型
2. torch.save(model.state_dict(), path) # 保存模型的参数
相应地有两种方法加载保存的模型:
1. model = torch.load(path) # 直接加载模型
2. model = Model() # 先初始化一个模型
model.load_state_dict(torch.load(path))

本文介绍了PyTorch在保存和加载模型过程中遇到的设备错误和模型错误,包括如何处理模型默认在GPU上的问题,解决不同PyTorch版本之间的兼容性问题,以及如何正确保存和恢复优化器的状态。总结建议在保存时将模型和优化器参数转到CPU,并根据训练需求决定是否保存优化器状态。
最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



