深度学习模型的训练、评估与分类问题处理
1. 模型的保存与加载
为了让我们的 StepByStep 类更加完善,现在为其添加保存和加载模型的功能。大部分代码与之前的类似,只是这里使用类的属性而非局部变量。
1.1 保存模型
def save_checkpoint(self, filename):
# Builds dictionary with all elements for resuming training
checkpoint = {
'epoch': self.total_epochs,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'loss': self.losses,
'val_loss': self.val_losses
}
torch.save(checkpoint, filename)
setattr(StepByStep, 'save_checkpoint', save_checkpoint)
1.2 加载模型
def load_checkpoint(self, filename):
# Loads dictionary
checkpoint = torch.load
超级会员免费看
订阅专栏 解锁全文
6876

被折叠的 条评论
为什么被折叠?



