PyTorch Lightning CLI 高级配置指南

PyTorch Lightning CLI 高级配置指南

【免费下载链接】pytorch-lightning Lightning-AI/pytorch-lightning: PyTorch Lightning 是一个轻量级的高级接口,用于简化 PyTorch 中深度学习模型的训练流程。它抽象出了繁杂的工程细节,使研究者能够专注于模型本身的逻辑和实验设计,同时仍能充分利用PyTorch底层的灵活性。 【免费下载链接】pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/py/pytorch-lightning

前言

PyTorch Lightning 是一个轻量级的 PyTorch 封装框架,它通过提供标准化的训练流程简化了深度学习模型的开发。其中,LightningCLI 是一个强大的命令行接口工具,可以极大地简化模型训练和实验配置的过程。本文将深入探讨 LightningCLI 的高级配置技巧,帮助开发者更灵活地使用这一工具。

实例化模式

默认情况下,LightningCLI 在初始化时会自动调用与子命令关联的训练器函数。但有时我们可能需要更精细的控制流程:

cli = LightningCLI(MyModel, run=False)  # 默认为True
# 需要手动调用fit方法
cli.trainer.fit(cli.model)

在这种模式下,子命令不会被添加到解析器中。这种模式特别适合需要实现自定义逻辑但又不希望子类化 CLI 的情况,同时还能利用 CLI 的实例化和参数解析能力。

训练器回调与类类型参数

训练器类中最重要的参数之一是 callbacks。与简单的数字或字符串参数不同,callbacks 期望接收一个回调类实例的列表。在配置文件中,每个回调需要以字典形式指定:

trainer:
  callbacks:
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_weights_only: true
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: 'epoch'

命令行语法示例:

python ... \
    --trainer.callbacks+=EarlyStopping \
    --trainer.callbacks.patience=5 \
    --trainer.callbacks+=LearningRateMonitor \
    --trainer.callbacks.logging_interval=epoch

注意使用 + 来追加新回调到列表中,init_args 会应用到前一个追加的回调上。

多模型与数据集配置

CLI 可以编写为通过导入路径和初始化参数指定模型和/或数据模块:

cli = LightningCLI(MyModelBaseClass, MyDataModuleBaseClass, 
                  subclass_mode_model=True, 
                  subclass_mode_data=True)

对应的配置文件示例:

model:
  class_path: mycode.mymodels.MyModel
  init_args:
    decoder_layers: [2, 4]
    encoder_layers: 12
data:
  class_path: mycode.mydatamodules.MyDataModule
  init_args: {...}

包含多个子模块的模型

对于需要多个子模块的复杂模型,可以使用依赖注入模式:

class MyMainModel(LightningModule):
    def __init__(self, encoder: nn.Module, decoder: nn.Module):
        super().__init__()
        self.save_hyperparameters()
        self.encoder = encoder
        self.decoder = decoder

对应的配置文件:

model:
  encoder:
    class_path: mycode.myencoders.MyEncoder
    init_args: {...}
  decoder:
    class_path: mycode.mydecoders.MyDecoder
    init_args: {...}

固定优化器和调度器

有时我们可能需要固定优化器和学习率调度器,而不是允许任意选择:

class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.add_optimizer_args(torch.optim.Adam)
        parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)

这样在配置中,optimizerlr_scheduler 组将只接受指定类的参数。

多优化器和调度器

对于需要多个优化器的场景,可以禁用自动配置:

cli = MyLightningCLI(MyModel, auto_configure_optimizers=False)

模型实现示例:

class MyModel(LightningModule):
    def __init__(self, optimizer1: OptimizerCallable, optimizer2: OptimizerCallable):
        super().__init__()
        self.save_hyperparameters()
        self.optimizer1 = optimizer1
        self.optimizer2 = optimizer2

    def configure_optimizers(self):
        return [self.optimizer1(self.parameters()), 
                self.optimizer2(self.parameters())]

从Python直接运行

虽然 LightningCLI 主要用于命令行工具,但也可以直接从Python运行:

def cli_main(args: ArgsType = None):
    cli = LightningCLI(MyModel, ..., args=args)
    ...

if __name__ == "__main__":
    cli_main()

然后在其他Python代码中可以这样调用:

cli_main(["--trainer.max_epochs=100", "--model.encoder_layers=24"])

总结

PyTorch Lightning 的 CLI 工具提供了极大的灵活性,可以满足从简单到复杂的各种深度学习实验配置需求。通过掌握这些高级配置技巧,开发者可以更高效地管理模型训练过程,实现更复杂的实验设计,同时保持代码的整洁和可维护性。

【免费下载链接】pytorch-lightning Lightning-AI/pytorch-lightning: PyTorch Lightning 是一个轻量级的高级接口,用于简化 PyTorch 中深度学习模型的训练流程。它抽象出了繁杂的工程细节,使研究者能够专注于模型本身的逻辑和实验设计,同时仍能充分利用PyTorch底层的灵活性。 【免费下载链接】pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/py/pytorch-lightning

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值