PyTorch Lightning 基础教程:模型检查点的保存与加载
pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/pyt/pytorch-lightning
什么是模型检查点?
在深度学习模型训练过程中,模型性能会随着数据不断输入而发生变化。检查点(Checkpoint)机制是保存训练过程中模型状态的最佳实践,它能够在训练的关键节点保存模型版本。训练完成后,开发者可以选择性能最优的检查点作为最终模型。
检查点还有另一个重要作用:当训练意外中断时,可以从最近的检查点恢复训练,避免从头开始。PyTorch Lightning 的检查点与原生 PyTorch 完全兼容,可以无缝切换使用。
检查点包含哪些内容?
PyTorch Lightning 的检查点不仅保存模型权重,还完整记录了训练环境的所有关键状态,即使在复杂的分布式训练场景下也能完美恢复。一个完整的检查点包含:
- 16位精度训练的缩放因子(如使用混合精度训练)
- 当前训练周期(epoch)
- 全局训练步数(global step)
- LightningModule 的状态字典(state_dict)
- 所有优化器的状态
- 所有学习率调度器的状态
- 所有回调函数的状态(针对有状态的回调)
- 数据模块的状态(针对有状态的数据模块)
- 模型初始化时的超参数
- 数据模块初始化时的超参数
- 训练循环的状态
如何保存检查点?
PyTorch Lightning 会自动在当前工作目录保存检查点,记录最后一个训练周期的状态:
# 使用默认Trainer即可自动启用检查点功能
trainer = Trainer()
如需自定义检查点保存路径,可使用 default_root_dir
参数:
# 在每个epoch结束时将检查点保存到指定路径
trainer = Trainer(default_root_dir="some/path/")
从检查点加载模型
加载带有权重和超参数的 LightningModule:
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
# 关闭随机性、dropout等训练专用层
model.eval()
# 使用模型进行预测
y_hat = model(x)
超参数的保存与使用
在 LightningModule 中,可以通过 self.save_hyperparameters()
自动保存所有初始化参数:
class MyLightningModule(LightningModule):
def __init__(self, learning_rate, another_parameter, *args, **kwargs):
super().__init__()
self.save_hyperparameters()
保存的超参数可通过检查点的 "hyper_parameters" 键访问:
checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# 输出示例: {"learning_rate": 0.01, "another_parameter": "value"}
加载的模型也可以直接访问这些超参数:
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
print(model.learning_rate)
使用不同参数初始化
如果使用了 self.save_hyperparameters()
,可以在加载时覆盖原始超参数:
# 原始训练保存的模型使用这些参数
LitModel(in_dim=32, out_dim=10)
# 加载时使用保存的参数
model = LitModel.load_from_checkpoint(PATH)
# 加载时覆盖部分参数
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
对于大型PyTorch模块参数,可以选择不保存为超参数,但加载时必须提供这些参数:
class LitAutoencoder(L.LightningModule):
def __init__(self, encoder, decoder):
...
model = LitAutoEncoder.load_from_checkpoint(PATH, encoder=encoder, decoder=decoder)
与原生PyTorch模块的互操作
PyTorch Lightning 检查点与原生 PyTorch 模块完全兼容。例如,对于以下模型结构:
class Encoder(nn.Module):
...
class Decoder(nn.Module):
...
class Autoencoder(L.LightningModule):
def __init__(self, encoder, decoder, *args, **kwargs):
super().__init__()
self.encoder = encoder
self.decoder = decoder
autoencoder = Autoencoder(Encoder(), Decoder())
训练完成后,可以提取特定组件的权重:
checkpoint = torch.load(CKPT_PATH)
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}
禁用检查点功能
如需完全禁用检查点功能:
trainer = Trainer(enable_checkpointing=False)
恢复完整训练状态
如需恢复完整的训练状态(包括epoch、step、学习率调度器等):
model = LitModel()
trainer = Trainer()
# 自动恢复模型、训练周期、步数、学习率调度器等状态
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
通过这套完善的检查点机制,PyTorch Lightning 为用户提供了灵活可靠的模型保存与恢复方案,无论是用于模型部署、继续训练还是结果复现,都能得心应手。
pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/pyt/pytorch-lightning
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考