PyTorch Lightning 基础教程:模型检查点的保存与加载

PyTorch Lightning 基础教程:模型检查点的保存与加载

【免费下载链接】pytorch-lightning Lightning-AI/pytorch-lightning: PyTorch Lightning 是一个轻量级的高级接口,用于简化 PyTorch 中深度学习模型的训练流程。它抽象出了繁杂的工程细节,使研究者能够专注于模型本身的逻辑和实验设计,同时仍能充分利用PyTorch底层的灵活性。 【免费下载链接】pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/py/pytorch-lightning

什么是模型检查点?

在深度学习模型训练过程中,检查点(Checkpoint)是一个至关重要的概念。它记录了模型在特定训练阶段的完整状态,包括模型权重、优化器状态、学习率调度器状态等关键信息。

检查点主要有两大用途:

  1. 模型版本管理:在训练过程中定期保存检查点,可以保留模型在不同训练阶段的状态,便于后期选择性能最佳的版本
  2. 训练恢复:当训练过程意外中断时,可以从最近的检查点恢复训练,避免从头开始

PyTorch Lightning 的检查点机制设计得非常完善,不仅适用于 Lightning 框架本身,也可以直接用于原生 PyTorch 项目。

检查点的内部结构

PyTorch Lightning 的检查点文件包含了恢复模型训练所需的全部信息,即使在复杂的分布式训练环境中也能完美工作。一个典型的检查点包含以下内容:

  • 模型相关:LightningModule 的状态字典(state_dict)
  • 训练状态:当前训练周期(epoch)、全局步数(global step)
  • 优化相关:所有优化器的状态、所有学习率调度器的状态
  • 精度设置:16位精度训练的缩放因子(如使用混合精度训练)
  • 回调状态:所有有状态回调(stateful callbacks)的当前状态
  • 数据模块:数据模块(DataModule)的状态(如有状态数据模块)
  • 配置信息:模型和数据模块初始化时的所有超参数
  • 训练循环:训练循环(Loops)的当前状态

保存检查点

PyTorch Lightning 会自动在训练过程中保存检查点,默认保存在当前工作目录下。

# 最简单的使用方式,自动保存检查点
trainer = Trainer()

如果需要自定义保存路径,可以使用 default_root_dir 参数:

# 指定检查点保存路径
trainer = Trainer(default_root_dir="your/custom/path/")

从检查点加载模型

加载带有权重和超参数的 LightningModule 非常简单:

# 加载检查点
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")

# 将模型设置为评估模式(关闭dropout等随机操作)
model.eval()

# 使用模型进行预测
y_hat = model(x)

超参数管理

PyTorch Lightning 提供了便捷的超参数保存机制:

class MyLightningModule(LightningModule):
    def __init__(self, learning_rate, another_param, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()  # 自动保存所有传入的超参数

加载模型后可以直接访问这些超参数:

model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
print(model.learning_rate)  # 直接访问保存的超参数

覆盖超参数

加载检查点时可以覆盖原有的超参数:

# 原始训练时使用的参数
LitModel(in_dim=32, out_dim=10)

# 加载时使用原始参数
model = LitModel.load_from_checkpoint(PATH)

# 加载时覆盖in_dim参数
model = LitModel.load_from_checkpoint(PATH, in_dim=128)

对于大型模块参数(如完整的PyTorch模块),可以这样处理:

class LitAutoencoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        ...

# 加载时需要显式传递encoder和decoder
model = LitAutoEncoder.load_from_checkpoint(PATH, encoder=encoder, decoder=decoder)

与原生PyTorch的互操作性

PyTorch Lightning的检查点完全兼容原生PyTorch。例如,假设我们有以下模型结构:

class Encoder(nn.Module):
    ...

class Decoder(nn.Module):
    ...

class Autoencoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

训练完成后,可以这样提取子模块的权重:

checkpoint = torch.load(CKPT_PATH)
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() 
                  if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() 
                  if k.startswith("decoder.")}

禁用检查点功能

如果确实不需要检查点功能,可以这样禁用:

trainer = Trainer(enable_checkpointing=False)

恢复训练状态

要从检查点完全恢复训练状态(包括优化器状态、学习率调度器等),应该这样操作:

model = LitModel()
trainer = Trainer()

# 自动恢复模型、训练周期、步数、学习率调度器等所有状态
trainer.fit(model, ckpt_path="path/to/checkpoint.ckpt")

注意:旧版的 resume_from_checkpoint 参数已在 PyTorch Lightning 1.0.0 及以上版本中废弃,请使用 fit() 方法中的 ckpt_path 参数替代。

最佳实践建议

  1. 定期保存:对于长时间训练的任务,建议设置合理的检查点保存频率
  2. 版本管理:重要的检查点应该备份并标注对应的训练配置和性能指标
  3. 存储空间:检查点文件可能较大,需要确保有足够的存储空间
  4. 恢复验证:重要的检查点应该验证其是否可以正确恢复训练
  5. 生产环境:在生产环境中,建议实现自动化的检查点管理和恢复机制

通过合理使用检查点功能,可以大大提高深度学习项目的开发效率和可靠性。

【免费下载链接】pytorch-lightning Lightning-AI/pytorch-lightning: PyTorch Lightning 是一个轻量级的高级接口,用于简化 PyTorch 中深度学习模型的训练流程。它抽象出了繁杂的工程细节,使研究者能够专注于模型本身的逻辑和实验设计,同时仍能充分利用PyTorch底层的灵活性。 【免费下载链接】pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/py/pytorch-lightning

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值