Pytorch 模型存储
pytorch 模型存储有两种方法:
存储参数
torch.save(the_model.state_dict(),PATH)
the_model=THe_model()
the_model.load_state_dict(torch.load(PATH))
存储整个模型
torch.save(the_model,PATH)
the_model=torch.load(PATH)
官方文档推荐使用第一种参数方法,第二种方法模型的恢复受限于文件路径,文件系统等,容易出问题,不易在其他工程中使用。