PyTorch提供了两种保存训练好的模型的方法。
第一种是只保存模型参数,这也是推荐的方法:
#保存
torch.save(the_model.state_dict(), PATH)
#读取
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
第二种方法保存整个模型:
#保存
torch.save(the_model, PATH)
#读取
the_model = torch.load(PATH)
注:PATH的格式:'./model_file_name/the_model_name.tar'
本文详细介绍了PyTorch中两种保存模型的方法:仅保存模型参数和保存整个模型。并提供了具体的代码示例,包括如何保存和读取模型。
4103

被折叠的 条评论
为什么被折叠?



