理解 PyTorch Lightning 的内部训练流程确实有助于更好地调试和优化你的模型。让我们详细解析你在调试过程中观察到的调用顺序,以及 PyTorch Lightning 在 trainer.fit(system)
调用时到底在做什么。
PyTorch Lightning 的训练流程概述
PyTorch Lightning 是一个高级封装库,旨在简化 PyTorch 的训练过程。它通过抽象化和模块化的方式,让用户专注于模型的定义和训练逻辑,而无需处理大量的样板代码。核心组件是 LightningModule
,你在代码中通过继承 pl.LightningModule
创建了 NeRF_pl
类。
主要步骤解析
从你提供的调试信息来看,训练流程大致如下:
- 初始化和准备数据
- 配置优化器和调度器
- 设置加速器和精度插件
- 开始训练循环
- 执行训练和验证步骤
让我们逐步深入每一步的细节。
1. 初始化和准备数据
在 main.py
中,训练开始于 trainer.fit(system)
,其中 system
是 NeRF_pl
的实例。
def main():
torch.cuda.empty_cache()
args = get_opts()
system = NeRF_pl(args)
logger = pl.loggers.TensorBoardLogger(save_dir=args.logs_dir, name=args.exp_name, default_hp_metric=False)
ckpt_callback = pl.callbacks.ModelCheckpoint(
dirpath="{}/{}".format(args.ckpts_dir, args.exp_name),
filename="{epoch:d}",
monitor="val/psnr",
mode="max",
save_top_k=-1,
every_n_val_epochs=args.save_every_n_epochs
)
trainer = pl.Trainer(
max_steps=args.max_train_steps,
logger=logger,
callbacks=[ckpt_callback],
resume_from_checkpoint=args.ckpt_path,
gpus=[args.gpu_id],
auto_select_gpus=False,
deterministic=True,
benchmark=True,
weights_summary=None,
num_sanity_val_steps=2,
check_val_every_n_epoch=1,
profiler="simple"
)
trainer.fit(system)
trainer.fit(system)
的作用
trainer.fit(system)
是启动训练的核心调用。它会依次执行以下步骤:
- 调用
prepare_data()
- 设置数据加载器 (
train_dataloader()
和val_dataloader()
) - 配置优化器和调度器 (
configure_optimizers()
) - 设置加速器(如 GPU)和精度插件(如混合精度)
- 开始训练循环,执行
training_step()
和validation_step()
2. 配置优化器和调度器
在调试过程中,你看到代码进入了 configure_optimizers()
方法。这是 LightningModule
中定义的一个方法,用于指定优化器和学习率调度器。
def configure_optimizers(self):
parameters = train_utils.get_parameters(self.models)
self.optimizer = torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=0)
max_epochs = self.get_current_epoch(self.args.max_train_steps)
scheduler = train_utils.get_scheduler(optimizer=self.optimizer, lr_scheduler='step', num_epochs=max_epochs)
return {
'optimizer': self.optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'interval': 'epoch'
}
}
作用解析
-
获取模型参数
parameters = train_utils.get_parameters(self.models)
train_utils.get_parameters(self.models)
提取了self.models
(包括粗略模型和精细模型)的所有可训练参数。
-
配置优化器
self.optimizer = torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=0)
- 使用 Adam 优化器,设置学习率和权重衰减(此处为 0)。
-
配置学习率调度器
max_epochs = self.get_current_epoch(self.args.max_train_steps) scheduler = train_utils.get_scheduler(optimizer=self.optimizer, lr_scheduler='step', num_epochs=max_epochs)
- 计算总的 epoch 数,并配置一个学习率调度器(这里是
step
类型的调度器,按步长调整学习率)。
- 计算总的 epoch 数,并配置一个学习率调度器(这里是
-
返回优化器和调度器
return { 'optimizer': self.optimizer, 'lr_scheduler': { 'scheduler': scheduler, 'interval': 'epoch' } }
- 返回一个字典,包含优化器和学习率调度器的配置。这些配置将由 PyTorch Lightning 内部处理。
返回值解释
optimizers
: 包含所有配置好的优化器(这里是 Adam 优化器)的列表。lr_schedulers
: 包含学习率调度器的配置列表。optimizer_frequencies
: 通常为空,除非你有特定的优化器更新频率需求。
3. 设置加速器和精度插件
接下来,代码进入 accelerator.py
的 setup_optimizers()
方法。
def setup_optimizers(self):
self.optimizers = optimizers
作用解析
self.optimizers = optimizers
: 将优化器列表分配给加速器对象,以便在训练过程中正确使用。
然后,代码跳转到 setup_precision_plugin()
方法。
def setup():
self.setup_precision_plugin(self.precision_plugin)
精度插件的作用
- 精度插件(
precision_plugin
): 负责配置训练时的数值精度(如 32 位浮点、16 位混合精度等)。- 混合精度训练: 通过使用 16 位浮点数来加速训练过程并减少显存占用,同时保持较高的数值精度。
4. 开始训练循环
在完成优化器和精度插件的配置后,PyTorch Lightning 开始实际的训练循环。训练过程会自动调用以下方法:
training_step()
: 处理每个训练批次。validation_step()
: 处理每个验证批次。on_epoch_start()
和on_epoch_end()
: 处理每个 epoch 的开始和结束。
5. trainer.fit(system)
后的详细调用流程
为了更好地理解具体调用顺序,以下是 trainer.fit(system)
调用后的详细流程:
-
初始化优化器和调度器
configure_optimizers()
被调用,返回优化器和调度器的配置。init_optimizers()
初始化这些优化器和调度器。
-
配置加速器
- 调用
accelerator.setup_optimizers()
将优化器配置到加速器环境中。 - 设置精度插件,确保模型在正确的数值精度下运行。
- 调用
-
启动训练循环
- 调用
dispatch()
,这一步开始进入训练循环的核心部分。 - 训练循环包括:
- Epoch Loop: 对于每个 epoch:
- Batch Loop: 对于每个训练批次:
- 调用
training_step()
- 执行前向传播和反向传播
- 优化器更新参数
- 调用
- 验证(如果配置了验证):
- 对于每个验证批次:
- 调用
validation_step()
- 计算验证损失和指标
- 调用
- 对于每个验证批次:
- 学习率调度:
- 更新学习率调度器(如有配置)
- Batch Loop: 对于每个训练批次:
- Epoch Loop: 对于每个 epoch:
- 调用
6. 为什么会有来回跳转?
由于 PyTorch Lightning 的封装性质,很多内部调用会涉及到不同的模块和方法,这导致在调试时看起来像是代码“来回跳转”。实际上,这些跳转是正常的,代表着框架内部不同组件之间的协作。
例如:
trainer.py
中的_run()
方法负责整体的训练过程管理。data_connector.py
负责数据的准备和分发。accelerator.py
负责模型在不同硬件环境下的适配和优化。
这种模块化设计使得代码结构更加清晰,同时也方便用户扩展和定制。
7. 训练是否已经开始?在哪一步?
根据你的描述,训练还没有实际开始。训练开始的标志是在优化器和调度器配置完成后,进入训练循环的核心部分(即调用 training_step()
和 validation_step()
)。
具体确认训练开始的步骤
-
优化器和调度器配置完成
- 已经完成了
configure_optimizers()
和setup_precision_plugin()
的调用。
- 已经完成了
-
进入训练循环
- PyTorch Lightning 会进入训练循环,开始对每个 epoch 进行迭代。
-
调用
training_step()
- 当训练循环开始处理第一个批次时,会调用
training_step()
方法,这标志着实际的训练已经开始。
- 当训练循环开始处理第一个批次时,会调用
如何确认训练开始
在调试时,可以在 training_step()
和 validation_step()
方法中添加断点或打印语句,以确认是否已经进入了实际的训练步骤。
def training_step(self, batch, batch_nb):
print("Entering training_step")
# 现有代码...
当 training_step()
被调用时,你会看到相应的打印输出,确认训练确实已经开始。
总结
trainer.fit(system)
启动了整个训练过程,涉及数据准备、优化器和调度器配置、加速器和精度插件设置等步骤。- 优化器和调度器 通过
configure_optimizers()
配置,并在内部被初始化和设置。 - 精度插件 通过
setup_precision_plugin()
配置,用于控制训练时的数值精度(如混合精度)。 - 训练循环 在所有配置完成后开始,调用
training_step()
和validation_step()
执行具体的训练和验证操作。 - 来回跳转 是因为 PyTorch Lightning 内部模块化设计的结果,正常现象,表示不同组件在协同工作。
- 确认训练开始:通过在
training_step()
中添加调试点或打印语句,可以明确训练是否已经进入实际的训练循环。
希望这些解释能帮助你更好地理解 sat-nerf
项目的训练流程。如果你有进一步的问题或需要更具体的帮助,请随时告诉我!