pytorch保存模型的方式有两种:
- 将整个网络都都保存下来
- 仅保存和加载模型参数(推荐使用这样的方法)
代码部分
方法1
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
方法2
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
效果展示
其中
best.mdl
是第二中方法保存的
model.pk
l则是第一种方法保存的