PyTorch Lightning 日志系统完全指南

PyTorch Lightning 日志系统完全指南

概述

PyTorch Lightning 提供了一个强大而灵活的日志系统,可以帮助开发者轻松跟踪和可视化模型训练过程中的各种指标。本文将全面介绍 PyTorch Lightning 的日志功能,包括内置日志记录器的使用、自定义日志记录、日志频率控制以及高级日志技巧。

内置日志记录器

PyTorch Lightning 支持多种流行的日志记录器,包括:

  • TensorBoardLogger (默认)
  • CSVLogger
  • MLFlowLogger
  • NeptuneLogger
  • WandbLogger
  • CometLogger

这些日志记录器会自动绘制"全局步数 vs 周期"图表,部分记录器还会提供额外的可视化功能。

基本使用示例

from lightning.pytorch import Trainer

# 使用默认的TensorBoard记录器,日志保存在lightning_logs/目录
trainer = Trainer()

# 使用自定义记录器
from lightning.pytorch import loggers as pl_loggers

tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
trainer = Trainer(logger=tb_logger)

# 同时使用多个记录器
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
comet_logger = pl_loggers.CometLogger(save_dir="logs/")
trainer = Trainer(logger=[tb_logger, comet_logger])

从LightningModule中记录日志

PyTorch Lightning 提供了自动和手动两种日志记录方式。

自动日志记录

使用log()log_dict()方法可以方便地记录指标:

def training_step(self, batch, batch_idx):
    self.log("my_metric", x)  # 记录单个指标
    self.log_dict({"acc": acc, "recall": recall})  # 同时记录多个指标

log()方法提供了丰富的控制选项:

def training_step(self, batch, batch_idx):
    self.log("my_loss", loss, 
             on_step=True,    # 记录每一步
             on_epoch=True,   # 记录整个周期
             prog_bar=True,   # 显示在进度条
             logger=True)     # 记录到日志文件

日志行为默认设置

不同钩子函数中的默认日志行为:

| 钩子函数 | on_step | on_epoch | |---------|---------|----------| | 训练周期相关钩子 | False | True | | 训练批次相关钩子 | True | False | | 验证相关钩子 | False | True |

手动记录非标量数据

对于直方图、图像等非标量数据,可以直接使用记录器的实验接口:

def training_step(self):
    # 获取当前记录器
    tensorboard = self.logger.experiment
    tensorboard.add_image("sample", image)
    tensorboard.add_histogram("weights", model.weights)

自定义日志记录器

你可以通过继承Logger类创建自定义记录器:

from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment
from lightning.pytorch.utilities import rank_zero_only

class MyLogger(Logger):
    @property
    def name(self):
        return "MyLogger"

    @property
    def version(self):
        return "0.1"

    @rank_zero_only
    def log_hyperparams(self, params):
        # 记录超参数
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # 记录指标
        pass

控制日志频率

日志记录频率

默认每50步记录一次日志,可以通过log_every_n_steps调整:

trainer = Trainer(log_every_n_steps=10)  # 每10步记录一次

日志写入频率

不同记录器有不同的刷新频率,例如CSVLogger可以通过flush_logs_every_n_steps设置。

进度条集成

可以将任何指标添加到进度条:

def training_step(self, batch, batch_idx):
    self.log("my_loss", loss, prog_bar=True)

自定义进度条内容:

from lightning.pytorch.callbacks.progress import TQDMProgressBar

class CustomProgressBar(TQDMProgressBar):
    def get_metrics(self, *args, **kwargs):
        items = super().get_metrics(*args, **kwargs)
        items.pop("v_num", None)  # 移除版本号显示
        return items

控制台日志配置

可以调整Lightning的控制台日志级别:

import logging

# 设置全局日志级别
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

# 模块级日志配置
logger = logging.getLogger("lightning.pytorch.core")
logger.addHandler(logging.FileHandler("core.log"))

超参数记录

检查点会自动保存超参数:

lightning_checkpoint = torch.load(filepath)
hyperparams = lightning_checkpoint["hyper_parameters"]

部分记录器支持超参数可视化,例如TensorBoard的hparams标签页:

# 使用默认指标
def validation_step(self, batch, batch_idx):
    self.log("hp_metric", some_scalar)

# 使用自定义或多指标
def on_train_start(self):
    self.logger.log_hyperparams(self.hparams, 
                              {"hp/metric_1": 0, "hp/metric_2": 0})

def validation_step(self, batch, batch_idx):
    self.log("hp/metric_1", some_scalar_1)
    self.log("hp/metric_2", some_scalar_2)

远程文件系统支持

PyTorch Lightning 支持将日志保存到多种远程文件系统,包括云存储服务。具体配置可参考相关文档。

总结

PyTorch Lightning 的日志系统提供了从简单到高级的全面日志功能,无论是初学者还是高级用户都能找到适合自己需求的日志方案。通过合理配置日志系统,可以大大提高模型训练和调试的效率。

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

窦恺墩

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值