错误原因
保存的网络模型,但是读取的时候按照读取参数的方式使用了model.load_state_dict()
,如下
model = Generator(3, 3, 64)
model_weight_path = ''
model.load_state_dict(torch.load(mdoel_weight_path))
model.eval
正确方式
model_weight_path = "./generator.pkl"
model = torch.load(model_weight_path)
model.eval()
总结
-
方法一
torch.save(net.state_dict(),path)
功能:保存训练完的网络的各层参数(即weights和bias)
其中:net.state_dict()
获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth)
net2.load_state_dict(torch.load(path))
功能:加载保存到path中的各层参数到神经网络
注意:不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数 -
方法二
torch.save(net,path)
功能:保存训练完的整个网络模型(不止weights和bias)
net2=torch.load(path)
功能:加载保存到path中的整个神经网络