保存加载用于推理的常规Checkpoint/或继续训练
https://blog.youkuaiyun.com/qq_38765642/article/details/109784913
if (epoch+1) % checkpoint_interval == 0:
checkpoint = {"model_state_dict": net.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"epoch": epoch}
path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch)
torch.save(checkpoint, path_checkpoint)
#或者
#保存
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, PATH)
#加载
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()
# - 或者 -
model.train()
checkpoint检查点:不仅保存模型的参数,优化器参数,还有loss,epoch等(相当于一个保存模型的文件夹)
state_dict可以用于保存模型的参数 及优化。
PyTorch模型保存与加载:检查点与状态字典
该博客介绍了如何在PyTorch中保存和加载模型的检查点,包括模型参数、优化器状态以及训练的epoch和loss。通过使用`torch.save()`和`torch.load()`,可以在训练过程中定期保存模型,并在之后的推理或继续训练时加载这些信息。检查点文件包含了模型的重要状态,而state_dict则用于存储模型参数。
3万+

被折叠的 条评论
为什么被折叠?



