本文来自pytorch官网,主要讲述如何保存和导入模型。
一般而言,保存和导入模型我们必须得熟悉三个核心函数:
1.torch.save
:将序列化的对象保存起来,可以使用这个函数去保存模型,张量和字典。并且使用Python的pickle
去实现序列化(serialization)。
2.torch.load
:使用pickle
的unpickling
工具将pickled
对象文件反序列化到内存中。
3.torch.nn.Module.load_state_dict
:使用反序列state_dict
导入模型参数。
一、什么是state_dict
?
torch.nn.Module
中的可学习的参数(weights和bias)可以使用mode.parameters()
获得。state_dict是一个简单的python字典对象,它将每个层中的参数映射到张量中。
具有可学习参数的层(卷积层、线性层等)才有model's state_dict
中的条目。优化器对象(torch.optim
)也有state_dict
,其中包含关于优化器状态以及所使用的超参数的信息。
二、保存和加载模型
Save::torch.save(model.state_dict(), PATH)