【DL】Keras框架调用多个GPU并行训练,并保存模型

本文介绍了如何在Keras框架中使用多GPU进行数据并行训练的方法,包括配置多GPU环境、实现多GPU并行训练的过程及注意事项。还讨论了在使用ModelCheckpoint时可能遇到的问题及其解决办法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

多GPU训练可以在一定程度上提高训练的效率,Keras在以TensorFlow为底层的时候,可以方便的调用多GPU进行训练。其训练的原理是将一个神经网络在多块GPU中进行初始化,并且将一个batch的数据分配到不同的GPU中反向传播结束后,再使用CPU进行合并更新参数

举个例子:
同时使用两块GPU,batch_size=32的话,Keras会将16个给0号GPU,另外16个给1号GPU,训练结束后将二者拼接为一个32的batch并进行反向传播更新参数。

多GPU其实分为两种使用情况:数据并行和设备并行。

数据并行将目标模型在多个设备上各复制一份,并使用每个设备上的复制品处理整个数据集的不同部分数据。

Keras在 keras.utils.multi_gpu_model 中提供有内置函数,该函数可以产生任意模型的数据并行版本,最高支持在8片GPU上并行。

数据并行利用多块GPU同时训练多个batch数据,运行在每块GPU上的模型为同一个神经网络,网络结构完全一样,并且共享模型参数

设备并行是在不同设备上运行同一个模型的不同部分,当模型含有多个并行结构,例如含有两个分支时,这种方式很适合。

我们大多数时候要用到的都是数据并行

这里就给出数据并行的多GPU训练示例:

from keras.utils.training_utils import multi_gpu_model   #导入keras多GPU函数

model = get_model()
parallel_model = multi_gpu_model(model, gpus=2) # 设置使用2个gpu,该句放在模型compile之前
parallel_model.compile(loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])
hist = parallel_model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs_num, validation_data=(x_test, y_test), verbose=1, callbacks=callbacks)

使用多 gpu 并行训练并使用 ModelCheckpoint() 可能遇到的问题:

  1. 报错(楼主尚未遇到)
  2. 可以正常训练,保存模型,但是,进行测试时,权重混乱,预测失败。(目前遇到过)

当需要使用Keras的checkpoint的功能的时候,使用如下函数。注意:初始化的参数model,就是单个的model,而不是用于多GPU训练的parallel_model。

class ParallelModelCheckpoint(ModelCheckpoint):
    def __init__(self,model,filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        self.single_model = model
        super(ParallelModelCheckpoint,self).__init__(filepath, monitor, verbose,save_best_only, save_weights_only,mode, period)
 
    def set_model(self, model):
        super(ParallelModelCheckpoint,self).set_model(self.single_model)

调用如下:

checkpoint = ParallelModelCheckpoint(model, save_path, monitor='val_auc', verbose=1, save_best_only=True, save_weights_only=False, mode='max')
 
callbacks = [check_point, early_stopping, tbCallBack]
 
model.fit(x_train, y_train,
          batch_size=config.batch_size,
          epochs=config.nb_epoch, 
          verbose=config.verbose, 
          shuffle=True,
          validation_data=(x_valid, y_valid),
          callbacks=callbacks)

总体上本篇博客的思想就是:楼主遇到一个bug,keras框架下,单卡训练模型,保存模型,加载测试模型,正常;但是,双卡训练模型,保存模型正常,在最后测试的时候出现问题,主要是因为使用Keras的checkpoint的功能的时候需要写一个class,上述代码已给出,当然还有其他的解决方案,请看这里,here或者这里点我!!!

参考

keras中文文档
https://keras.io/zh/models/model/

https://www.jianshu.com/p/4203a6435ab5

https://blog.youkuaiyun.com/u012862372/article/details/80367607

https://blog.youkuaiyun.com/ywcpig/article/details/105659637

https://blog.youkuaiyun.com/qq_34564612/article/details/79209965

<think>好的,用户让我详细介绍一下Deeplearning4j。首先,我需要确认Deeplearning4j是什么。根据之前的知识,它是一个面向Java和Scala的开源深度学习库,主要针对企业级应用。用户可能已经了解过Hadoop,现在想转向深度学习,但希望继续使用Java生态系统,所以才会问到Deeplearning4j。 接下来,用户的需求可能是什么?他们可能是在比较不同的深度学习框架,比如TensorFlow或PyTorch,但需要Java环境下的解决方案。也可能是在企业环境中需要集成深度学习模型到现有的Java系统中,比如使用Hadoop处理大数据后,用Deeplearning4j进行模型训练。 用户的身份可能是Java开发者,或者是大数据工程师,希望将深度学习整合到现有的大数据架构中。他们可能更关注框架的分布式训练能力、与Hadoop/Spark的兼容性,以及在生产环境中的部署便利性。 需要详细说明Deeplearning4j的核心功能、架构、生态系统、优缺点以及应用场景。同时要对比其他框架,突出其优势,比如对Java的支持、分布式计算能力、与Hadoop生态的集成等。还要提到它的底层依赖,比如ND4J,类似于NumPy的库,这对于性能优化很重要。 可能用户没有明说的需求是,他们需要知道Deeplearning4j的学习曲线如何,是否有足够的社区支持和文档资源。因此,在回答中可以提到其社区活跃度、文档质量以及企业支持情况。 另外,考虑到用户之前询问过Hadoop,可能需要强调Deeplearning4j与Hadoop/Spark的协同工作能力,比如在YARN上运行分布式训练,或者使用Spark进行数据预处理,这对用户现有的架构可能很重要。 还需要指出Deeplearning4j的不足,比如在学术界不如TensorFlow或PyTorch流行,可能预训练模型较少,社区资源相对有限。但如果是企业环境,特别是Java栈的公司,这些缺点可能不那么重要,反而其与现有系统的整合能力更为关键。 最后,总结部分要强调Deeplearning4j的适用场景,比如企业级应用、需要Java/Scala集成、分布式训练需求等,帮助用户判断是否适合他们的项目。</think>以下是关于 **Deeplearning4j(DL4J)** 的详细介绍,涵盖其设计目标、核心架构、功能特性、应用场景及优缺点: --- ### **一、Deeplearning4j 是什么?** **Deeplearning4j(DL4J)** 是一个开源的 **分布式深度学习框架**,专为 Java 和 Scala 生态系统设计,尤其适用于企业级生产环境。其核心目标是: - 在 JVM(Java 虚拟机)上高效运行深度学习模型。 - 与 Hadoop、Spark 等大数据工具无缝集成。 - 支持分布式训练和多 GPU 加速,满足企业大规模数据需求。 与 TensorFlow、PyTorch 等框架不同,**DL4J 更注重 Java/Scala 生态的兼容性**,适合需要将深度学习模型嵌入现有 Java 应用或大数据管道的场景。 --- ### **二、核心架构与依赖** DL4J 的底层由多个模块协作组成,核心依赖包括: 1. **ND4J(N-Dimensional Arrays for Java)** - 功能类似 NumPy,提供高效的**多维数组计算**,支持 CPU/GPU 加速。 - 是 DL4J 的数学计算引擎,支持张量操作(如矩阵乘法、卷积)。 2. **DataVec** - 数据预处理工具,用于数据清洗、转换和特征工程。 - 支持从 CSV、图像、文本、视频等格式加载数据,转换为 DL4J 可用的格式。 3. **LibND4J** - 基于 C++ 的高性能计算库,为 ND4J 提供底层加速(如 CUDA 支持)。 4. **Arbiter** - 超参数调优工具,支持自动化搜索最优模型配置。 --- ### **三、功能特性** #### **1. 支持的神经网络类型** - **卷积神经网络(CNN)**:图像分类、目标检测。 - **循环神经网络(RNN/LSTM/GRU)**:时间序列预测、自然语言处理。 - **深度强化学习(DQN/A3C)**:游戏 AI、自动化决策。 - **自编码器(Autoencoder)**:数据降维、异常检测。 - **生成对抗网络(GAN)**:图像生成、数据增强。 #### **2. 分布式训练** - **与 Apache Spark 集成**:可在 Spark 集群上分布式训练模型,支持数据并行模型并行。 - **与 Hadoop 兼容**:通过 YARN 管理资源,直接读取 HDFS 上的数据。 #### **3. 生产部署** - 支持将模型导出为 **ONNX 格式**(与其他框架互操作)。 - 提供 **POJO(Plain Old Java Object)** 序列化,可直接嵌入 Java 应用。 - 支持 **TensorFlow/Keras 模型导入**(通过 `keras-import` 模块)。 #### **4. 硬件加速** - 支持 CUDA 和 cuDNN,可利用 NVIDIA GPU 加速计算。 - 多 GPU 训练通过 **NCCL** 库实现高效通信。 --- ### **四、DL4J 生态系统** | **工具/库** | **用途** | |--------------------|--------------------------------------------------------------------------| | **ND4J** | 张量计算引擎,支持 CPU/GPU 加速。 | | **DataVec** | 数据预处理工具,支持数据标准化、缺失值填充、图像增强等。 | | **RL4J** | 强化学习库,实现 DQN、A3C 等算法。 | | **Python4J** | 允许在 Python 中调用 DL4J 模型(通过桥接)。 | | **Skymind** | 商业支持公司,提供企业级解决方案(如模型监控、自动化部署)。 | --- ### **五、DL4J 的优势** 1. **企业级集成** - 与 Java/Scala 生态无缝兼容,适合在 Hadoop、Spark、Kafka 等系统中部署。 - 支持 Spring Boot 等企业框架,便于构建微服务。 2. **分布式计算能力** - 原生支持 Spark 分布式训练,适合处理 TB 级数据。 3. **生产就绪** - 模型可直接部署为轻量级 JAR 包,无需依赖 Python 环境。 4. **灵活性** - 允许用户自定义网络层、损失函数和优化器。 --- ### **六、DL4J 的局限性** 1. **社区规模较小** - 相比 TensorFlow/PyTorch,用户和开发者社区较小,学习资源有限。 2. **学术界使用较少** - 学术界更倾向于 Python 框架,预训练模型(如 BERT、ResNet)较少直接支持 DL4J。 3. **调试复杂性** - JVM 的内存管理和 GPU 配置需要一定经验。 --- ### **七、典型应用场景** 1. **金融风控** - 使用 LSTM 检测信用卡交易异常。 2. **工业物联网(IIoT)** - 基于传感器数据训练 CNN 模型,预测设备故障。 3. **推荐系统** - 在 Spark 集群上分布式训练深度协同过滤模型。 4. **自然语言处理** - 构建企业级文本分类系统(如客服工单自动分类)。 --- ### **八、与其他框架的对比** | **框架** | **特点** | **适用场景** | |-----------------|--------------------------------------------------------------------------|--------------------------| | **Deeplearning4j** | JVM 原生支持,企业级集成,分布式训练 | Java/Scala 生态、大数据集成 | | **TensorFlow** | 生态丰富,支持移动端(TFLite)、浏览器(TensorFlow.js) | 研究到生产的全流程覆盖 | | **PyTorch** | 动态计算图,学术界主流框架,灵活易调试 | 研究原型开发、快速实验 | | **Keras** | 高层 API,简化模型构建,适合快速开发 | 小型项目、教育用途 | --- ### **九、快速入门示例** ```java // 构建一个简单的多层感知机(MLP) MultiLayerConfiguration config = new NeuralNetConfiguration.Builder() .seed(123) .updater(new Adam(0.01)) .activation(Activation.RELU) .weightInit(WeightInit.XAVIER) .list() .layer(new DenseLayer.Builder().nIn(784).nOut(250).build()) .layer(new DenseLayer.Builder().nIn(250).nOut(100).build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .activation(Activation.SOFTMAX).nIn(100).nOut(10).build()) .build(); MultiLayerNetwork model = new MultiLayerNetwork(config); model.init(); // 加载数据(MNIST 手写数字数据集) DataSetIterator trainData = new MnistDataSetIterator(64, true, 12345); DataSetIterator testData = new MnistDataSetIterator(64, false, 12345); // 训练模型 for (int i = 0; i < 10; i++) { model.fit(trainData); Evaluation eval = model.evaluate(testData); System.out.println(eval.stats()); } ``` --- ### **十、总结** **Deeplearning4j** 是 **Java/Scala 开发者进入深度学习领域的首选工具**,尤其适合需要与企业大数据系统(如 Hadoop、Spark)集成的场景。尽管在学术界影响力有限,但其生产就绪特性、分布式能力和 JVM 兼容性使其在金融、物联网、企业服务等领域具有独特优势。对于已有 Java 技术栈的团队,DL4J 是避免跨语言开发复杂性的理想选择。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

机器不学习我学习

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值