PyTorch模型保存与加载(原文链接)
(原文更详细)
常用:
保存/加载 state_dict:
保存
torch.save(model.state_dict(), PATH)
加载
model = 自己的网络(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
PyTorch模型保存与加载(原文链接)
(原文更详细)
常用:
保存/加载 state_dict:
保存
torch.save(model.state_dict(), PATH)
加载
model = 自己的网络(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()