PyTorch Lightning 日志系统完全指南
概述
PyTorch Lightning 提供了一套强大而灵活的日志系统,可以帮助开发者轻松记录训练过程中的各种指标和数据。本文将全面介绍 Lightning 的日志功能,包括内置日志器使用、自定义日志方法、日志频率控制等核心内容。
内置日志器支持
PyTorch Lightning 支持多种流行的日志记录工具,开发者可以根据需求选择合适的日志器:
- TensorBoardLogger:默认日志器,使用 TensorBoard 可视化训练过程
- CSVLogger:将训练指标记录到 CSV 文件
- MLFlowLogger:集成 MLflow 实验跟踪平台
- CometLogger:连接 Comet.ml 实验管理工具
- NeptuneLogger:与 Neptune.ai 平台集成
- WandbLogger:支持 Weights & Biases 实验跟踪
默认情况下,Lightning 使用 TensorBoard 日志器,并将日志存储在 lightning_logs/ 目录中。
from lightning.pytorch import Trainer
# 自动记录到 lightning_logs/ 目录
trainer = Trainer()
查看 TensorBoard 日志:
tensorboard --logdir=lightning_logs/
自定义日志器配置
开发者可以轻松配置不同的日志器:
from lightning.pytorch import loggers as pl_loggers
# 使用 TensorBoard 日志器
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
trainer = Trainer(logger=tb_logger)
# 使用 Comet 日志器
comet_logger = pl_loggers.CometLogger(save_dir="logs/")
trainer = Trainer(logger=comet_logger)
甚至可以使用多个日志器同时记录:
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
comet_logger = pl_loggers.CometLogger(save_dir="logs/")
trainer = Trainer(logger=[tb_logger, comet_logger])
LightningModule 中的日志记录
自动日志记录
在 LightningModule 中,可以使用 log() 或 log_dict() 方法记录指标:
def training_step(self, batch, batch_idx):
self.log("my_metric", x) # 记录单个指标
# 或者使用字典记录多个指标
self.log_dict({"acc": acc, "recall": recall})
log() 方法提供了丰富的控制选项:
def training_step(self, batch, batch_idx):
self.log("my_loss", loss,
on_step=True, # 记录每一步
on_epoch=True, # 记录整个 epoch
prog_bar=True, # 显示在进度条
logger=True) # 记录到日志器
主要参数说明:
on_step:是否记录当前步骤的指标on_epoch:是否在 epoch 结束时自动累积并记录指标prog_bar:是否显示在进度条logger:是否记录到日志器reduce_fx:epoch 结束时对步骤值的归约函数(默认为 torch.mean)sync_dist:是否跨设备同步指标(分布式训练时使用)
手动记录非标量数据
对于直方图、图像等非标量数据,可以直接使用日志器的实验接口:
def training_step(self):
# 获取当前日志器(如 TensorBoard)
tensorboard = self.logger.experiment
tensorboard.add_image("sample", image)
tensorboard.add_histogram("weights", model_weights)
自定义日志器实现
开发者可以创建自定义日志器,只需继承 Logger 类并实现必要方法:
from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities import rank_zero_only
class MyLogger(Logger):
@property
def name(self):
return "MyLogger"
@property
def version(self):
return "0.1"
@rank_zero_only
def log_hyperparams(self, params):
# 记录超参数
pass
@rank_zero_only
def log_metrics(self, metrics, step):
# 记录指标
pass
日志频率控制
默认情况下,Lightning 每 50 步记录一次日志。可以通过 Trainer 的 log_every_n_steps 参数调整:
trainer = Trainer(log_every_n_steps=10) # 每10步记录一次
进度条配置
可以将任何指标添加到进度条:
def training_step(self, batch, batch_idx):
self.log("my_loss", loss, prog_bar=True)
自定义进度条内容:
from lightning.pytorch.callbacks.progress import TQDMProgressBar
class CustomProgressBar(TQDMProgressBar):
def get_metrics(self, *args, **kwargs):
items = super().get_metrics(*args, **kwargs)
items.pop("v_num", None) # 移除版本号显示
return items
控制台日志配置
可以调整 Lightning 的控制台日志级别:
import logging
# 设置 Lightning 的根日志级别
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
# 模块级日志配置,重定向到文件
logger = logging.getLogger("lightning.pytorch.core")
logger.addHandler(logging.FileHandler("core.log"))
超参数记录
Lightning 会自动在检查点中保存超参数:
lightning_checkpoint = torch.load(filepath)
hyperparams = lightning_checkpoint["hyper_parameters"]
部分日志器还支持记录超参数到实验平台。例如,TensorBoardLogger 会将超参数显示在 hparams 标签页。
# 使用默认指标
def validation_step(self, batch, batch_idx):
self.log("hp_metric", some_scalar)
# 使用自定义或多指标
def on_train_start(self):
self.logger.log_hyperparams(self.hparams,
{"hp/metric_1": 0, "hp/metric_2": 0})
def validation_step(self, batch, batch_idx):
self.log("hp/metric_1", some_scalar_1)
self.log("hp/metric_2", some_scalar_2)
远程文件系统支持
PyTorch Lightning 支持将日志保存到各种文件系统,包括本地文件系统和多个云存储提供商。开发者可以根据需要配置不同的存储后端。
通过本文的介绍,开发者可以全面掌握 PyTorch Lightning 的日志系统,灵活地记录和监控模型训练过程中的各种信息。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



