文章目录
在wandb和pytorch lightning的官方文档中都有 在pytorch lightning中使用wandb的使用方法,链接分别为 https://docs.wandb.ai/guides/integrations/lightning#log-gradients-parameter-histogram-and-model-topology和 https://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/pytorch_lightning.loggers.WandbLogger.html?highlight=wandblogger。本笔记先以官方文档中内容为主进行使用方法的讲解,然后结合一个简单的使用ResNet网络图像分类的例子进行解析,部分没有提到的内容可自行在上述官方文档查询学习
使用前提
使用之前,如笔记wandb在pytorch中的使用记录中记录,需要先安装wandb,注册账号后使用密钥链接账号
使用解析
初始化
使用方面,pytorch lightning框架中提供了WandbLogger接口,在完成wandb安装后,直接使用WandbLogger接口提供的各种方法就能进行各类数据的记录,与单独使用wandb功能方面一致。使用时,只需先初始化一个WandbLogger类的对象,然后在定义trainer时将其作为logger传入即可,如下代码所示
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
wandb_logger = WandbLogger() # 初始化个WandbLogger对象
trainer = Trainer(logger=wandb_logger) # 初始化trainer
进行个WandbLogger类实例化时有以下常用参数
| 参数 | 描述 |
|---|---|
| project | 定义要登录的 wandb 项目 |
| name | 对当前wandb运行记录设置名称 |
| log_model | 设置记录参数范围:如果 log_model=“all” 记录所有模型,如果 log_model=True 在训练结束时记录 |
| save_dir | 保存数据的路径 |
模型超参数保存
初始化时将WandbLogger的实例化对象设置给trainer后,继承pl.LightningModule进行网络定义时,在__int__函数中直接调用self.save_hyperparameters()即可进行超参数保存
class LitModule(LightningModule):
def __init__(self, *args, **kwarg):
self.save_hyperparameters()
记录其他配置参数
# add one parameter
wandb_logger.experiment.config["key"] = value
# add multiple parameters
wandb_logger.experiment.config.update({
key1: val1, key2: val2})
# use directly wandb module
wandb.config["key"] = value
wandb.config.update()
记录梯度、参数直方图和模型拓扑
如wandb中一样,调用WandbLogger.watch()方法进行设置,如代码案例中main()函数所示,该方法可以进行如下几种设置
# log gradients and model topology
wandb_logger.watch(model)
# log gradients, parameter histogram and model topology
wandb_logger.watch(model, log="all")
# change log frequency of gradients and parameters (100 steps by default)
wandb_logger.watch(model, log_freq=500)
# do not log graph (in case of errors)
wandb_logger.watch(model, log_graph=False)
记录metric
可以通过在 LightningModule 中调用 self.log(‘my_metric_name’, metric_vale) 将指标记录到 W&B,例如在 代码案例中training_step() 或 validation_step() 方法中
记录metric的最小值/最大值
使用 wandb 的 define_metric 函数,可以定义是否希望 wandb 汇总指标显示该指标的最小值、最大值、平均值或最佳值。如果未使用 define_metric,那么最后记录的值将出现在您的摘要指标中。有关更多信息,请参阅此处的 define_metric 参考文档和此处的指南
class My_LitModule(LightningModule):
...
def validation_step(self, batch, batch_idx):
if trainer.global_step == 0:
wandb.define_metric('val_accuracy', summary='max') # 显示val_accuracy指标的最大值
preds, loss, acc = self._get_preds_loss_accuracy(batch)
# Log loss and metric
self.log('val_loss', loss)
self.log('val_accuracy', acc)
return preds
记录图像、文本等
WandbLogger 具有用于记录媒体的 log_image、log_text 和 log_table 方法;也可以直接调用 wandb.log 或 trainer.logger.experiment.log 来记录其他媒体类型,例如音频、分子、点云、3D 对象等
注意:在trainer中使用 wandb.log 或 trainer.logger.experiment.log 时,请确保在被传递的字典中也包含“global_step”:trainer.global_step。这样,可以将当前记录的信息与通过其他方法记录的信息对齐。
记录图像
# using tensors, numpy arrays or PIL images
wandb_logger.log_image(key="samples", images=[img1, img2])
# adding captions
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])
# using file path
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])
# using .log in the trainer
trainer.logger.experiment.log({
"samples": [wandb.Image(img, caption=caption)
for (img, caption) in my_images]
})
记录文本
# da

本文介绍了如何在PyTorch Lightning中利用Wandb进行模型训练,包括超参数保存、梯度、参数直方图、模型拓扑、metric记录、图像和文本数据的跟踪,以及在多GPU环境下的正确配置。通过一个ResNet图像分类示例,详细展示了关键操作和代码实现。
最低0.47元/天 解锁文章
1053

被折叠的 条评论
为什么被折叠?



