方案1 (推荐方案)
保存模型
torch.save(the_model.state_dict(), PATH)
加载模型
the_model = TheModelClass(*args, **kwargs) # 定义模型
the_model.load_state_dict(torch.load(PATH)) # 读取参数
注:本方案只保留和读取模型参数
。
方案2
保存模型
torch.save(the_model, PATH)
加载模型
the_model = torch.load(PATH)
注:本方案保存和读取整个模型
,实际上就是对象的序列化(serialization),此方案兼容性较差。例如你在windows平台save,在Linux平台load,可能会出问题。PyTorch更新版本,类的定义变了,也可能会出问题。
参考:
- https://stackoverflow.com/questions/42703500/best-way-to-save-a-trained-model-in-pytorch
- https://pytorch.org/docs/stable/notes/serialization.html#recommend-saving-models