简介
pytorch与保存、加载模型有关的常用函数3个:
torch.save(): 保存一个序列化的对象到磁盘,使用的是Python的pickle库来实现的。
torch.load(): 解序列化一个pickled对象并加载到内存当中。
torch.nn.Module.load_state_dict(): 加载一个解序列化的state_dict对象
1、state_dict是什么?
在pytorch中所有可学习的参数保存在model.parameters()中。state_dict是一个python中的字典,保存了各层与其参数张量之间的映射。torch.optim对象也有一个state_dict,它包含了optimizer的state,以及一些超参数。
2、在预测过程中保存和加载模型
2.1仅保存模型参数(推荐存储方式)
# save
torch.save(model.state_dict(), PATH)
# load
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()#切换为测试模式,防止参数更新
注意:
PyTorch 的1.6版本修改了 torch.save(),使用了一种新的基于 zipfile 的文件格式。Load 仍然保留以旧格式