wandb在pytorch lightning中的使用

本文介绍了如何在PyTorch Lightning中利用Wandb进行模型训练,包括超参数保存、梯度、参数直方图、模型拓扑、metric记录、图像和文本数据的跟踪,以及在多GPU环境下的正确配置。通过一个ResNet图像分类示例,详细展示了关键操作和代码实现。


在wandb和pytorch lightning的官方文档中都有 在pytorch lightning中使用wandb的使用方法,链接分别为 https://docs.wandb.ai/guides/integrations/lightning#log-gradients-parameter-histogram-and-model-topologyhttps://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/pytorch_lightning.loggers.WandbLogger.html?highlight=wandblogger。本笔记先以官方文档中内容为主进行使用方法的讲解,然后结合一个简单的使用ResNet网络图像分类的例子进行解析,部分没有提到的内容可自行在上述官方文档查询学习

使用前提

使用之前,如笔记wandb在pytorch中的使用记录中记录,需要先安装wandb,注册账号后使用密钥链接账号

使用解析

初始化

使用方面,pytorch lightning框架中提供了WandbLogger接口,在完成wandb安装后,直接使用WandbLogger接口提供的各种方法就能进行各类数据的记录,与单独使用wandb功能方面一致。使用时,只需先初始化一个WandbLogger类的对象,然后在定义trainer时将其作为logger传入即可,如下代码所示

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger()  # 初始化个WandbLogger对象
trainer = Trainer(logger=wandb_logger)  # 初始化trainer

进行个WandbLogger类实例化时有以下常用参数

参数 描述
project 定义要登录的 wandb 项目
name 对当前wandb运行记录设置名称
log_model 设置记录参数范围:如果 log_model=“all” 记录所有模型,如果 log_model=True 在训练结束时记录
save_dir 保存数据的路径

模型超参数保存

初始化时将WandbLogger的实例化对象设置给trainer后,继承pl.LightningModule进行网络定义时,在__int__函数中直接调用self.save_hyperparameters()即可进行超参数保存

class LitModule(LightningModule):
    def __init__(self, *args, **kwarg):
        self.save_hyperparameters()

记录其他配置参数

# add one parameter
wandb_logger.experiment.config["key"] = value

# add multiple parameters
wandb_logger.experiment.config.update({
   
   key1: val1, key2: val2})

# use directly wandb module
wandb.config["key"] = value
wandb.config.update()

记录梯度、参数直方图和模型拓扑

如wandb中一样,调用WandbLogger.watch()方法进行设置,如代码案例main()函数所示,该方法可以进行如下几种设置

# log gradients and model topology
wandb_logger.watch(model)

# log gradients, parameter histogram and model topology
wandb_logger.watch(model, log="all")

# change log frequency of gradients and parameters (100 steps by default)
wandb_logger.watch(model, log_freq=500)

# do not log graph (in case of errors)
wandb_logger.watch(model, log_graph=False)

记录metric

可以通过在 LightningModule 中调用 self.log(‘my_metric_name’, metric_vale) 将指标记录到 W&B,例如在 代码案例training_step()validation_step() 方法中

记录metric的最小值/最大值

使用 wandb 的 define_metric 函数,可以定义是否希望 wandb 汇总指标显示该指标的最小值、最大值、平均值或最佳值。如果未使用 define_metric,那么最后记录的值将出现在您的摘要指标中。有关更多信息,请参阅此处的 define_metric 参考文档和此处的指南

class My_LitModule(LightningModule):
    ...
    
    def validation_step(self, batch, batch_idx):
        if trainer.global_step == 0: 
            wandb.define_metric('val_accuracy', summary='max')  # 显示val_accuracy指标的最大值
        
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)
        return preds

记录图像、文本等

WandbLogger 具有用于记录媒体的 log_image、log_text 和 log_table 方法;也可以直接调用 wandb.log 或 trainer.logger.experiment.log 来记录其他媒体类型,例如音频、分子、点云、3D 对象等
注意:在trainer中使用 wandb.log 或 trainer.logger.experiment.log 时,请确保在被传递的字典中也包含“global_step”:trainer.global_step。这样,可以将当前记录的信息与通过其他方法记录的信息对齐。

记录图像

# using tensors, numpy arrays or PIL images
wandb_logger.log_image(key="samples", images=[img1, img2])

# adding captions
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# using file path
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

# using .log in the trainer
trainer.logger.experiment.log({
   
   
    "samples": [wandb.Image(img, caption=caption) 
    for (img, caption) in my_images]
})

记录文本

# da
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值