PyTorch Lightning CLI 高级配置指南
前言
PyTorch Lightning 是一个轻量级的 PyTorch 封装框架,它通过提供标准化的训练流程简化了深度学习模型的开发。其中,LightningCLI 是一个强大的命令行接口工具,可以极大地简化模型训练和实验配置的过程。本文将深入探讨 LightningCLI 的高级配置技巧,帮助开发者更灵活地使用这一工具。
实例化模式
默认情况下,LightningCLI 在初始化时会自动调用与子命令关联的训练器函数。但有时我们可能需要更精细的控制流程:
cli = LightningCLI(MyModel, run=False) # 默认为True
# 需要手动调用fit方法
cli.trainer.fit(cli.model)
在这种模式下,子命令不会被添加到解析器中。这种模式特别适合需要实现自定义逻辑但又不希望子类化 CLI 的情况,同时还能利用 CLI 的实例化和参数解析能力。
训练器回调与类类型参数
训练器类中最重要的参数之一是 callbacks。与简单的数字或字符串参数不同,callbacks 期望接收一个回调类实例的列表。在配置文件中,每个回调需要以字典形式指定:
trainer:
callbacks:
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
save_weights_only: true
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: 'epoch'
命令行语法示例:
python ... \
--trainer.callbacks+=EarlyStopping \
--trainer.callbacks.patience=5 \
--trainer.callbacks+=LearningRateMonitor \
--trainer.callbacks.logging_interval=epoch
注意使用 + 来追加新回调到列表中,init_args 会应用到前一个追加的回调上。
多模型与数据集配置
CLI 可以编写为通过导入路径和初始化参数指定模型和/或数据模块:
cli = LightningCLI(MyModelBaseClass, MyDataModuleBaseClass,
subclass_mode_model=True,
subclass_mode_data=True)
对应的配置文件示例:
model:
class_path: mycode.mymodels.MyModel
init_args:
decoder_layers: [2, 4]
encoder_layers: 12
data:
class_path: mycode.mydatamodules.MyDataModule
init_args: {...}
包含多个子模块的模型
对于需要多个子模块的复杂模型,可以使用依赖注入模式:
class MyMainModel(LightningModule):
def __init__(self, encoder: nn.Module, decoder: nn.Module):
super().__init__()
self.save_hyperparameters()
self.encoder = encoder
self.decoder = decoder
对应的配置文件:
model:
encoder:
class_path: mycode.myencoders.MyEncoder
init_args: {...}
decoder:
class_path: mycode.mydecoders.MyDecoder
init_args: {...}
固定优化器和调度器
有时我们可能需要固定优化器和学习率调度器,而不是允许任意选择:
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(torch.optim.Adam)
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
这样在配置中,optimizer 和 lr_scheduler 组将只接受指定类的参数。
多优化器和调度器
对于需要多个优化器的场景,可以禁用自动配置:
cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)
模型实现示例:
class MyModel(LightningModule):
def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
super().__init__()
self.save_hyperparameters()
self.optimizer1 = optimizer1
self.optimizer2 = optimizer2
def configure_optimizers(self):
return [self.optimizer1(self.parameters()),
self.optimizer2(self.parameters())]
从Python直接运行
虽然 LightningCLI 主要用于命令行工具,但也可以直接从Python运行:
def cli_main(args: ArgsType = None):
cli = LightningCLI(MyModel, ..., args=args)
...
if __name__ == "__main__":
cli_main()
然后在其他Python代码中可以这样调用:
cli_main(["--trainer.max_epochs=100", "--model.encoder_layers=24"])
总结
PyTorch Lightning 的 CLI 工具提供了极大的灵活性,可以满足从简单到复杂的各种深度学习实验配置需求。通过掌握这些高级配置技巧,开发者可以更高效地管理模型训练过程,实现更复杂的实验设计,同时保持代码的整洁和可维护性。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



