Determined AI 项目中的 Keras API 使用指南

Determined AI 项目中的 Keras API 使用指南

determined Determined is an open-source machine learning platform that simplifies distributed training, hyperparameter tuning, experiment tracking, and resource management. Works with PyTorch and TensorFlow. determined 项目地址: https://gitcode.com/gh_mirrors/de/determined

概述

在 Determined AI 平台上使用 Keras 进行模型训练时,keras.DeterminedCallback 是一个关键组件。本文将详细介绍如何在 Determined 环境中配置和运行 Keras 模型训练任务,包括分布式训练、检查点保存、TensorBoard 集成等重要功能。

准备工作

入口点配置

Determined 要求通过实验配置文件来启动训练任务。对于 Keras 训练,必须使用 Determined 的 TensorFlow 启动器来包装训练脚本:

entrypoint: >-
  python3 -m determined.launch.tensorflow --
  python3 my_train.py --my-arg...

这个启动器会自动配置正确的 TF_CONFIG 环境变量,无论是否使用分布式训练都能正确处理。

核心组件初始化

分布式上下文和策略

分布式训练需要提前创建 TensorFlow 的 Strategy 对象。使用 Determined 提供的辅助方法可以简化这一过程:

if __name__ == "__main__":
    distributed, strategy = det.core.DistributedContext.from_tf_config()
    with det.core.init(distributed=distributed) as core_context:
        main(core_context, strategy)

这段代码会创建分布式上下文和相应的策略对象,为后续模型构建做好准备。

模型构建与训练

模型构建

在分布式策略的作用域内构建和编译模型:

def main(core_context, strategy):
    with strategy.scope():
        model = my_build_model()
        model.compile(...)

这种方式确保模型能够适应分布式训练环境。

DeterminedCallback 创建

DeterminedCallback 是连接 Keras 训练与 Determined 平台的核心组件,它负责:

  1. 报告训练和测试指标
  2. 保存检查点并上传到存储
  3. 处理暂停/恢复信号

创建示例:

info = det.get_cluster_info()
assert info and info.task_type == "TRIAL", "此示例仅在集群上作为试验运行"

det_cb = det.keras.DeterminedCallback(
    core_context,
    checkpoint=info.latest_checkpoint,
    continue_id=info.trial.trial_id,
)

关键参数说明:

  • checkpoint: 指定从哪个检查点恢复训练
  • continue_id: 决定如何处理检查点(从零开始或继续训练)

数据加载

数据加载与常规 Keras 训练相同,但需要注意:

  1. 确保数据加载代码能在容器环境中运行
  2. 分布式训练时需要考虑数据分片策略

TensorBoard 集成

Determined 提供了增强版的 TensorBoard 回调,能够自动将指标上传到检查点存储:

tb_cb = det.keras.TensorBoard(core_context, ...)

使用方法与标准 Keras TensorBoard 回调相同,只是多了一个 core_context 参数。

启动训练

最后,将回调传递给 model.fit()

model.fit(
    ...,
    callbacks=[det_cb, tb_cb],
)

最佳实践

  1. 检查点管理:合理设置 checkpointcontinue_id 参数,确保训练可以正确恢复
  2. 分布式训练:确保数据加载逻辑正确处理分片
  3. 资源利用:监控 GPU 使用情况,调整批量大小以获得最佳性能
  4. 日志记录:利用 Determined 的日志系统记录关键训练指标

总结

通过本文介绍的方法,开发者可以轻松地将 Keras 模型训练任务迁移到 Determined 平台上,充分利用其分布式训练、检查点管理和监控功能。DeterminedCallback 的引入大大简化了与平台的集成工作,使开发者能够专注于模型本身而非基础设施管理。

对于更复杂的用例,建议参考 Determined 的完整文档,了解高级配置选项和性能调优技巧。

determined Determined is an open-source machine learning platform that simplifies distributed training, hyperparameter tuning, experiment tracking, and resource management. Works with PyTorch and TensorFlow. determined 项目地址: https://gitcode.com/gh_mirrors/de/determined

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

明会泽Irene

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值