trainer.fit解析

理解 PyTorch Lightning 的内部训练流程确实有助于更好地调试和优化你的模型。让我们详细解析你在调试过程中观察到的调用顺序,以及 PyTorch Lightning 在 trainer.fit(system) 调用时到底在做什么。

PyTorch Lightning 的训练流程概述

PyTorch Lightning 是一个高级封装库,旨在简化 PyTorch 的训练过程。它通过抽象化和模块化的方式,让用户专注于模型的定义和训练逻辑,而无需处理大量的样板代码。核心组件是 LightningModule,你在代码中通过继承 pl.LightningModule 创建了 NeRF_pl 类。

主要步骤解析

从你提供的调试信息来看,训练流程大致如下:

  1. 初始化和准备数据
  2. 配置优化器和调度器
  3. 设置加速器和精度插件
  4. 开始训练循环
  5. 执行训练和验证步骤

让我们逐步深入每一步的细节。


1. 初始化和准备数据

main.py 中,训练开始于 trainer.fit(system),其中 systemNeRF_pl 的实例。

def main():
    torch.cuda.empty_cache()
    args = get_opts()
    system = NeRF_pl(args)

    logger = pl.loggers.TensorBoardLogger(save_dir=args.logs_dir, name=args.exp_name, default_hp_metric=False)

    ckpt_callback = pl.callbacks.ModelCheckpoint(
        dirpath="{}/{}".format(args.ckpts_dir, args.exp_name),
        filename="{epoch:d}",
        monitor="val/psnr",
        mode="max",
        save_top_k=-1,
        every_n_val_epochs=args.save_every_n_epochs
    )

    trainer = pl.Trainer(
        max_steps=args.max_train_steps,
        logger=logger,
        callbacks=[ckpt_callback],
        resume_from_checkpoint=args.ckpt_path,
        gpus=[args.gpu_id],
        auto_select_gpus=False,
        deterministic=True,
        benchmark=True,
        weights_summary=None,
        num_sanity_val_steps=2,
        check_val_every_n_epoch=1,
        profiler="simple"
    )

    trainer.fit(system)
trainer.fit(system) 的作用

trainer.fit(system) 是启动训练的核心调用。它会依次执行以下步骤:

  1. 调用 prepare_data()
  2. 设置数据加载器 (train_dataloader()val_dataloader())
  3. 配置优化器和调度器 (configure_optimizers())
  4. 设置加速器(如 GPU)和精度插件(如混合精度)
  5. 开始训练循环,执行 training_step()validation_step()

2. 配置优化器和调度器

在调试过程中,你看到代码进入了 configure_optimizers() 方法。这是 LightningModule 中定义的一个方法,用于指定优化器和学习率调度器。

def configure_optimizers(self):
    parameters = train_utils.get_parameters(self.models)
    self.optimizer = torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=0)

    max_epochs = self.get_current_epoch(self.args.max_train_steps)
    scheduler = train_utils.get_scheduler(optimizer=self.optimizer, lr_scheduler='step', num_epochs=max_epochs)
    return {
        'optimizer': self.optimizer,
        'lr_scheduler': {
            'scheduler': scheduler,
            'interval': 'epoch'
        }
    }
作用解析
  1. 获取模型参数

    parameters = train_utils.get_parameters(self.models)
    
    • train_utils.get_parameters(self.models) 提取了 self.models(包括粗略模型和精细模型)的所有可训练参数。
  2. 配置优化器

    self.optimizer = torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=0)
    
    • 使用 Adam 优化器,设置学习率和权重衰减(此处为 0)。
  3. 配置学习率调度器

    max_epochs = self.get_current_epoch(self.args.max_train_steps)
    scheduler = train_utils.get_scheduler(optimizer=self.optimizer, lr_scheduler='step', num_epochs=max_epochs)
    
    • 计算总的 epoch 数,并配置一个学习率调度器(这里是 step 类型的调度器,按步长调整学习率)。
  4. 返回优化器和调度器

    return {
        'optimizer': self.optimizer,
        'lr_scheduler': {
            'scheduler': scheduler,
            'interval': 'epoch'
        }
    }
    
    • 返回一个字典,包含优化器和学习率调度器的配置。这些配置将由 PyTorch Lightning 内部处理。
返回值解释
  • optimizers: 包含所有配置好的优化器(这里是 Adam 优化器)的列表。
  • lr_schedulers: 包含学习率调度器的配置列表。
  • optimizer_frequencies: 通常为空,除非你有特定的优化器更新频率需求。

3. 设置加速器和精度插件

接下来,代码进入 accelerator.pysetup_optimizers() 方法。

def setup_optimizers(self):
    self.optimizers = optimizers
作用解析
  • self.optimizers = optimizers: 将优化器列表分配给加速器对象,以便在训练过程中正确使用。

然后,代码跳转到 setup_precision_plugin() 方法。

def setup():
    self.setup_precision_plugin(self.precision_plugin)
精度插件的作用
  • 精度插件(precision_plugin: 负责配置训练时的数值精度(如 32 位浮点、16 位混合精度等)。
    • 混合精度训练: 通过使用 16 位浮点数来加速训练过程并减少显存占用,同时保持较高的数值精度。

4. 开始训练循环

在完成优化器和精度插件的配置后,PyTorch Lightning 开始实际的训练循环。训练过程会自动调用以下方法:

  • training_step(): 处理每个训练批次。
  • validation_step(): 处理每个验证批次。
  • on_epoch_start()on_epoch_end(): 处理每个 epoch 的开始和结束。

5. trainer.fit(system) 后的详细调用流程

为了更好地理解具体调用顺序,以下是 trainer.fit(system) 调用后的详细流程:

  1. 初始化优化器和调度器

    • configure_optimizers() 被调用,返回优化器和调度器的配置。
    • init_optimizers() 初始化这些优化器和调度器。
  2. 配置加速器

    • 调用 accelerator.setup_optimizers() 将优化器配置到加速器环境中。
    • 设置精度插件,确保模型在正确的数值精度下运行。
  3. 启动训练循环

    • 调用 dispatch(),这一步开始进入训练循环的核心部分。
    • 训练循环包括:
      • Epoch Loop: 对于每个 epoch:
        • Batch Loop: 对于每个训练批次:
          • 调用 training_step()
          • 执行前向传播和反向传播
          • 优化器更新参数
        • 验证(如果配置了验证):
          • 对于每个验证批次:
            • 调用 validation_step()
            • 计算验证损失和指标
        • 学习率调度:
          • 更新学习率调度器(如有配置)

6. 为什么会有来回跳转?

由于 PyTorch Lightning 的封装性质,很多内部调用会涉及到不同的模块和方法,这导致在调试时看起来像是代码“来回跳转”。实际上,这些跳转是正常的,代表着框架内部不同组件之间的协作。

例如:

  • trainer.py 中的 _run() 方法负责整体的训练过程管理。
  • data_connector.py 负责数据的准备和分发。
  • accelerator.py 负责模型在不同硬件环境下的适配和优化。

这种模块化设计使得代码结构更加清晰,同时也方便用户扩展和定制。


7. 训练是否已经开始?在哪一步?

根据你的描述,训练还没有实际开始。训练开始的标志是在优化器和调度器配置完成后,进入训练循环的核心部分(即调用 training_step()validation_step())。

具体确认训练开始的步骤
  1. 优化器和调度器配置完成

    • 已经完成了 configure_optimizers()setup_precision_plugin() 的调用。
  2. 进入训练循环

    • PyTorch Lightning 会进入训练循环,开始对每个 epoch 进行迭代。
  3. 调用 training_step()

    • 当训练循环开始处理第一个批次时,会调用 training_step() 方法,这标志着实际的训练已经开始。
如何确认训练开始

在调试时,可以在 training_step()validation_step() 方法中添加断点或打印语句,以确认是否已经进入了实际的训练步骤。

def training_step(self, batch, batch_nb):
    print("Entering training_step")
    # 现有代码...

training_step() 被调用时,你会看到相应的打印输出,确认训练确实已经开始。


总结

  • trainer.fit(system) 启动了整个训练过程,涉及数据准备、优化器和调度器配置、加速器和精度插件设置等步骤。
  • 优化器和调度器 通过 configure_optimizers() 配置,并在内部被初始化和设置。
  • 精度插件 通过 setup_precision_plugin() 配置,用于控制训练时的数值精度(如混合精度)。
  • 训练循环 在所有配置完成后开始,调用 training_step()validation_step() 执行具体的训练和验证操作。
  • 来回跳转 是因为 PyTorch Lightning 内部模块化设计的结果,正常现象,表示不同组件在协同工作。
  • 确认训练开始:通过在 training_step() 中添加调试点或打印语句,可以明确训练是否已经进入实际的训练循环。

希望这些解释能帮助你更好地理解 sat-nerf 项目的训练流程。如果你有进一步的问题或需要更具体的帮助,请随时告诉我!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ajaxm

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值