- 错误代码:之前训练好了一个网络模型,并把它保存为’Net.pkl’,今天加载的时候出现了错误,加载代码为
model = torch.load('Net.pkl')
-
错误信息为:
-
错误原因:之前保存网络时用的方法是
torch.save(model, 'Nei.pkl')
,这样保存下来的Net.pkl是一个状态字典,而不是模型本身,也就是说Net.pkl中保存的只是网络的参数,而没有网络结构。 -
改正方法:应该先载入网络结构,再导入网络的参数,修改后代码如下,其中
BiGRU()
是写好的神经网络类。
model = BiGRU() # 导入网络结构
model.load_state_dict(torch.load('Net.pkl')) # 导入网络的参数
这样就不会报错了!