3行代码实现实验可复现:WandB超参数日志让MLOps不再是难题
【免费下载链接】MLOps-Basics 项目地址: https://gitcode.com/GitHub_Trending/ml/MLOps-Basics
你是否还在为复现同事的实验结果焦头烂额?训练代码一致却跑出不同指标?本文将通过MLOps-Basics项目中的week_1_wandb_logging模块,教你用WandB(Weights & Biases)构建可追溯的机器学习实验系统,3行核心代码解决90%的实验混乱问题。
读完本文你将掌握:
- 超参数自动记录与版本管理
- 训练指标可视化追踪技巧
- 实验结果一键复现的标准流程
- 错误预测样本自动抓取方法
为什么实验可复现性如此重要?
机器学习实验中,即使使用相同的代码,也可能因环境配置、随机种子、数据版本等差异导致结果截然不同。MLOps-Basics项目通过模块化设计展示了如何解决这一痛点,其中week_1_wandb_logging模块专门演示了WandB在实验追踪中的应用。
图1:MLOps-Basics项目的基础工作流,WandB处于实验追踪的核心位置
环境准备:3步搭建可复现实验环境
首先需要配置一致的开发环境,项目提供了标准化的环境配置流程:
- 创建虚拟环境:
conda create --name project-setup python=3.8
conda activate project-setup
- 安装依赖包:
pip install -r week_1_wandb_logging/requirements.txt
- 验证安装:
python -c "import wandb; print(wandb.__version__)"
详细环境配置可参考week_1_wandb_logging/README.md中的"Requirements"部分。
核心实现:WandB日志记录三要素
1. 初始化WandB实验
在week_1_wandb_logging/train.py中,通过WandbLogger与PyTorch Lightning无缝集成:
wandb_logger = WandbLogger(project="MLOps Basics", entity="raviraja")
trainer = pl.Trainer(
max_epochs=1,
logger=wandb_logger,
callbacks=[checkpoint_callback, SamplesVisualisationLogger(cola_data), early_stopping_callback],
log_every_n_steps=10,
deterministic=True
)
这段代码实现了三个关键功能:
- 关联到指定的WandB项目
- 配置日志记录频率
- 启用确定性训练模式(保证结果可复现)
2. 超参数自动记录
WandB会自动捕捉训练代码中的超参数,例如在model.py中定义的模型参数:
class ColaModel(pl.LightningModule):
def __init__(self, model_name="google/bert-base-cased", lr=2e-5):
super(ColaModel, self).__init__()
self.save_hyperparameters() # 自动记录超参数
self.bert = BertModel.from_pretrained(model_name)
self.classifier = nn.Linear(self.bert.config.hidden_size, 2)
# ...
通过self.save_hyperparameters()方法,所有初始化参数会自动上传到WandB dashboard,无需手动记录。
3. 训练指标实时追踪
在训练循环中,PyTorch Lightning的log方法会自动将指标发送到WandB:
def training_step(self, batch, batch_idx):
loss, logits, labels = self._shared_step(batch)
self.log("train/loss", loss, prog_bar=True, logger=True)
return loss
def validation_step(self, batch, batch_idx):
loss, logits, labels = self._shared_step(batch)
self.log("valid/loss", loss, prog_bar=True, logger=True)
return loss
这些指标会实时显示在WandB dashboard中,形成直观的训练曲线。
图2:WandB提供的实验仪表板,可同时对比多个实验结果
高级功能:错误样本可视化
项目实现了一个自定义回调SamplesVisualisationLogger,用于自动抓取并记录错误预测样本:
class SamplesVisualisationLogger(pl.Callback):
def __init__(self, datamodule):
super().__init__()
self.datamodule = datamodule
def on_validation_end(self, trainer, pl_module):
val_batch = next(iter(self.datamodule.val_dataloader()))
sentences = val_batch["sentence"]
outputs = pl_module(val_batch["input_ids"], val_batch["attention_mask"])
preds = torch.argmax(outputs.logits, 1)
labels = val_batch["label"]
df = pd.DataFrame(
{"Sentence": sentences, "Label": labels.numpy(), "Predicted": preds.numpy()}
)
wrong_df = df[df["Label"] != df["Predicted"]]
trainer.logger.experiment.log(
{
"examples": wandb.Table(dataframe=wrong_df, allow_mixed_types=True),
"global_step": trainer.global_step,
}
)
这段代码在week_1_wandb_logging/train.py的13-37行定义,实现了错误样本的自动捕捉和表格化展示。
运行与监控:完整实验流程
启动训练
执行训练命令:
python week_1_wandb_logging/train.py
首次运行时,WandB会提示你登录账号(免费注册即可)。训练完成后,会在终端输出类似以下信息:
wandb: Synced 5 W&B file(s), 4 media file(s), 3 artifact file(s) and 0 other file(s)
wandb: Synced proud-mountain-77: https://wandb.ai/raviraja/MLOps%20Basics/runs/3vp1twdc
查看实验结果
点击输出的链接,即可在浏览器中查看完整的实验记录,包括:
- 训练/验证损失曲线
- 超参数配置
- 错误预测样本表格
- 模型 checkpoint
图3:WandB提供的实验看板,展示完整的训练过程和结果分析
最佳实践:提升实验可复现性的5个技巧
- 设置随机种子:在训练代码中固定随机种子,确保每次运行从相同初始状态开始
- 使用确定性算法:如代码中
deterministic=True配置 - 版本化数据:后续模块week_3_dvc展示了如何用DVC管理数据版本
- 记录环境信息:通过
wandb.save("requirements.txt")保存依赖版本 - 自动化实验流程:结合week_6_github_actions实现CI/CD流水线
总结与后续学习路径
通过WandB日志记录,MLOps-Basics项目展示了如何解决机器学习实验的可复现性问题。这一模块是整个MLOps流程的基础,后续的Docker容器化(week_5_docker)、模型部署(week_8_serverless)等环节都建立在可复现的实验基础之上。
图4:MLOps-Basics项目涵盖的完整MLOps流程,从数据处理到模型部署
下一篇我们将介绍如何使用Hydra进行配置管理,进一步提升实验的灵活性和可维护性。收藏本文,关注项目更新,不错过后续的MLOps实践指南!
本文代码示例均来自MLOps-Basics项目的week_1_wandb_logging模块,完整实现可通过以下命令获取:
git clone https://gitcode.com/GitHub_Trending/ml/MLOps-Basics
【免费下载链接】MLOps-Basics 项目地址: https://gitcode.com/GitHub_Trending/ml/MLOps-Basics
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考






