一、保存
# 保存模型到路径
torch.save(Batch_Net(28*28, 300, 100, 10), r'C:\Users\11868\Desktop\net.pth')
# 保存模型的参数
torch.save(model.state_dict(), r'C:\Users\11868\Desktop\state_dict.pth')
注意:若模型初始化需要指定参数,则保存时要添加参数。
二、载入
# 加载模型
model = torch.load(r'C:\Users\11868\Desktop\net.pth')
# 加载参数
model.load_state_dict(torch.load(r'C:\Users\11868\Desktop\state_dict.pth'))
model.eval() # 将模型改为测试模式
注意:必须调用model.eval(),以便在运行推断之前将dropout和batch规范化层设置为评估模式。如果不这样做,将会产生不一致的推断结果。
本文详细介绍了如何使用PyTorch保存整个模型和仅保存模型参数的方法,并提供了实际的代码示例。同时,强调了在加载模型进行预测前,务必调用model.eval()将模型切换至评估模式的重要性。
1647

被折叠的 条评论
为什么被折叠?



