PyTorch Lightning 日志系统完全指南

PyTorch Lightning 日志系统完全指南

【免费下载链接】pytorch-lightning Lightning-AI/pytorch-lightning: PyTorch Lightning 是一个轻量级的高级接口,用于简化 PyTorch 中深度学习模型的训练流程。它抽象出了繁杂的工程细节,使研究者能够专注于模型本身的逻辑和实验设计,同时仍能充分利用PyTorch底层的灵活性。 【免费下载链接】pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/py/pytorch-lightning

概述

PyTorch Lightning 提供了一套强大而灵活的日志系统,可以帮助开发者轻松记录训练过程中的各种指标和数据。本文将全面介绍 Lightning 的日志功能,包括内置日志器使用、自定义日志方法、日志频率控制等核心内容。

内置日志器支持

PyTorch Lightning 支持多种流行的日志记录工具,开发者可以根据需求选择合适的日志器:

  • TensorBoardLogger:默认日志器,使用 TensorBoard 可视化训练过程
  • CSVLogger:将训练指标记录到 CSV 文件
  • MLFlowLogger:集成 MLflow 实验跟踪平台
  • CometLogger:连接 Comet.ml 实验管理工具
  • NeptuneLogger:与 Neptune.ai 平台集成
  • WandbLogger:支持 Weights & Biases 实验跟踪

默认情况下,Lightning 使用 TensorBoard 日志器,并将日志存储在 lightning_logs/ 目录中。

from lightning.pytorch import Trainer

# 自动记录到 lightning_logs/ 目录
trainer = Trainer()

查看 TensorBoard 日志:

tensorboard --logdir=lightning_logs/

自定义日志器配置

开发者可以轻松配置不同的日志器:

from lightning.pytorch import loggers as pl_loggers

# 使用 TensorBoard 日志器
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
trainer = Trainer(logger=tb_logger)

# 使用 Comet 日志器
comet_logger = pl_loggers.CometLogger(save_dir="logs/")
trainer = Trainer(logger=comet_logger)

甚至可以使用多个日志器同时记录:

tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
comet_logger = pl_loggers.CometLogger(save_dir="logs/")
trainer = Trainer(logger=[tb_logger, comet_logger])

LightningModule 中的日志记录

自动日志记录

在 LightningModule 中,可以使用 log()log_dict() 方法记录指标:

def training_step(self, batch, batch_idx):
    self.log("my_metric", x)  # 记录单个指标
    
    # 或者使用字典记录多个指标
    self.log_dict({"acc": acc, "recall": recall})

log() 方法提供了丰富的控制选项:

def training_step(self, batch, batch_idx):
    self.log("my_loss", loss, 
             on_step=True,    # 记录每一步
             on_epoch=True,   # 记录整个 epoch
             prog_bar=True,   # 显示在进度条
             logger=True)     # 记录到日志器

主要参数说明:

  • on_step:是否记录当前步骤的指标
  • on_epoch:是否在 epoch 结束时自动累积并记录指标
  • prog_bar:是否显示在进度条
  • logger:是否记录到日志器
  • reduce_fx:epoch 结束时对步骤值的归约函数(默认为 torch.mean)
  • sync_dist:是否跨设备同步指标(分布式训练时使用)

手动记录非标量数据

对于直方图、图像等非标量数据,可以直接使用日志器的实验接口:

def training_step(self):
    # 获取当前日志器(如 TensorBoard)
    tensorboard = self.logger.experiment
    tensorboard.add_image("sample", image)
    tensorboard.add_histogram("weights", model_weights)

自定义日志器实现

开发者可以创建自定义日志器,只需继承 Logger 类并实现必要方法:

from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities import rank_zero_only

class MyLogger(Logger):
    @property
    def name(self):
        return "MyLogger"

    @property
    def version(self):
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        # 记录超参数
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # 记录指标
        pass

日志频率控制

默认情况下,Lightning 每 50 步记录一次日志。可以通过 Trainerlog_every_n_steps 参数调整:

trainer = Trainer(log_every_n_steps=10)  # 每10步记录一次

进度条配置

可以将任何指标添加到进度条:

def training_step(self, batch, batch_idx):
    self.log("my_loss", loss, prog_bar=True)

自定义进度条内容:

from lightning.pytorch.callbacks.progress import TQDMProgressBar

class CustomProgressBar(TQDMProgressBar):
    def get_metrics(self, *args, **kwargs):
        items = super().get_metrics(*args, **kwargs)
        items.pop("v_num", None)  # 移除版本号显示
        return items

控制台日志配置

可以调整 Lightning 的控制台日志级别:

import logging

# 设置 Lightning 的根日志级别
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

# 模块级日志配置,重定向到文件
logger = logging.getLogger("lightning.pytorch.core")
logger.addHandler(logging.FileHandler("core.log"))

超参数记录

Lightning 会自动在检查点中保存超参数:

lightning_checkpoint = torch.load(filepath)
hyperparams = lightning_checkpoint["hyper_parameters"]

部分日志器还支持记录超参数到实验平台。例如,TensorBoardLogger 会将超参数显示在 hparams 标签页。

# 使用默认指标
def validation_step(self, batch, batch_idx):
    self.log("hp_metric", some_scalar)

# 使用自定义或多指标
def on_train_start(self):
    self.logger.log_hyperparams(self.hparams, 
                              {"hp/metric_1": 0, "hp/metric_2": 0})

def validation_step(self, batch, batch_idx):
    self.log("hp/metric_1", some_scalar_1)
    self.log("hp/metric_2", some_scalar_2)

远程文件系统支持

PyTorch Lightning 支持将日志保存到各种文件系统,包括本地文件系统和多个云存储提供商。开发者可以根据需要配置不同的存储后端。

通过本文的介绍,开发者可以全面掌握 PyTorch Lightning 的日志系统,灵活地记录和监控模型训练过程中的各种信息。

【免费下载链接】pytorch-lightning Lightning-AI/pytorch-lightning: PyTorch Lightning 是一个轻量级的高级接口,用于简化 PyTorch 中深度学习模型的训练流程。它抽象出了繁杂的工程细节,使研究者能够专注于模型本身的逻辑和实验设计,同时仍能充分利用PyTorch底层的灵活性。 【免费下载链接】pytorch-lightning 项目地址: https://gitcode.com/gh_mirrors/py/pytorch-lightning

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值