PyTorch Lightning 日志系统完全指南
概述
PyTorch Lightning 提供了一个强大而灵活的日志系统,可以帮助开发者轻松跟踪和可视化模型训练过程中的各种指标。本文将全面介绍 PyTorch Lightning 的日志功能,包括内置日志记录器的使用、自定义日志记录、日志频率控制以及高级日志技巧。
内置日志记录器
PyTorch Lightning 支持多种流行的日志记录器,包括:
- TensorBoardLogger (默认)
- CSVLogger
- MLFlowLogger
- NeptuneLogger
- WandbLogger
- CometLogger
这些日志记录器会自动绘制"全局步数 vs 周期"图表,部分记录器还会提供额外的可视化功能。
基本使用示例
from lightning.pytorch import Trainer
# 使用默认的TensorBoard记录器,日志保存在lightning_logs/目录
trainer = Trainer()
# 使用自定义记录器
from lightning.pytorch import loggers as pl_loggers
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
trainer = Trainer(logger=tb_logger)
# 同时使用多个记录器
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
comet_logger = pl_loggers.CometLogger(save_dir="logs/")
trainer = Trainer(logger=[tb_logger, comet_logger])
从LightningModule中记录日志
PyTorch Lightning 提供了自动和手动两种日志记录方式。
自动日志记录
使用log()
或log_dict()
方法可以方便地记录指标:
def training_step(self, batch, batch_idx):
self.log("my_metric", x) # 记录单个指标
self.log_dict({"acc": acc, "recall": recall}) # 同时记录多个指标
log()
方法提供了丰富的控制选项:
def training_step(self, batch, batch_idx):
self.log("my_loss", loss,
on_step=True, # 记录每一步
on_epoch=True, # 记录整个周期
prog_bar=True, # 显示在进度条
logger=True) # 记录到日志文件
日志行为默认设置
不同钩子函数中的默认日志行为:
| 钩子函数 | on_step | on_epoch | |---------|---------|----------| | 训练周期相关钩子 | False | True | | 训练批次相关钩子 | True | False | | 验证相关钩子 | False | True |
手动记录非标量数据
对于直方图、图像等非标量数据,可以直接使用记录器的实验接口:
def training_step(self):
# 获取当前记录器
tensorboard = self.logger.experiment
tensorboard.add_image("sample", image)
tensorboard.add_histogram("weights", model.weights)
自定义日志记录器
你可以通过继承Logger
类创建自定义记录器:
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities import rank_zero_only
class MyLogger(Logger):
@property
def name(self):
return "MyLogger"
@property
def version(self):
return "0.1"
@rank_zero_only
def log_hyperparams(self, params):
# 记录超参数
pass
@rank_zero_only
def log_metrics(self, metrics, step):
# 记录指标
pass
控制日志频率
日志记录频率
默认每50步记录一次日志,可以通过log_every_n_steps
调整:
trainer = Trainer(log_every_n_steps=10) # 每10步记录一次
日志写入频率
不同记录器有不同的刷新频率,例如CSVLogger可以通过flush_logs_every_n_steps
设置。
进度条集成
可以将任何指标添加到进度条:
def training_step(self, batch, batch_idx):
self.log("my_loss", loss, prog_bar=True)
自定义进度条内容:
from lightning.pytorch.callbacks.progress import TQDMProgressBar
class CustomProgressBar(TQDMProgressBar):
def get_metrics(self, *args, **kwargs):
items = super().get_metrics(*args, **kwargs)
items.pop("v_num", None) # 移除版本号显示
return items
控制台日志配置
可以调整Lightning的控制台日志级别:
import logging
# 设置全局日志级别
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
# 模块级日志配置
logger = logging.getLogger("lightning.pytorch.core")
logger.addHandler(logging.FileHandler("core.log"))
超参数记录
检查点会自动保存超参数:
lightning_checkpoint = torch.load(filepath)
hyperparams = lightning_checkpoint["hyper_parameters"]
部分记录器支持超参数可视化,例如TensorBoard的hparams标签页:
# 使用默认指标
def validation_step(self, batch, batch_idx):
self.log("hp_metric", some_scalar)
# 使用自定义或多指标
def on_train_start(self):
self.logger.log_hyperparams(self.hparams,
{"hp/metric_1": 0, "hp/metric_2": 0})
def validation_step(self, batch, batch_idx):
self.log("hp/metric_1", some_scalar_1)
self.log("hp/metric_2", some_scalar_2)
远程文件系统支持
PyTorch Lightning 支持将日志保存到多种远程文件系统,包括云存储服务。具体配置可参考相关文档。
总结
PyTorch Lightning 的日志系统提供了从简单到高级的全面日志功能,无论是初学者还是高级用户都能找到适合自己需求的日志方案。通过合理配置日志系统,可以大大提高模型训练和调试的效率。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考