DL4J中文文档/调优与训练/早停

早停是一种在训练神经网络过程中防止过拟合的技术。通过在每个epoch结束时评估模型在测试集上的性能,如果性能未在预设的epoch数内提升,则停止训练。此策略有助于自动确定训练的最佳终止点。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

什么是早停?

在训练神经网络时,需要对使用的设置(超参数)作出许多决策,以便获得良好的性能。一旦这样的超参数是训练epochs的数目:也就是说,数据集(epochs)的完整传递应该有几次?如果我们使用太少的epochs,我们可能欠拟合(即,不能从训练数据中学习我们能学的所有);如果我们使用太多的epochs,我们可能过拟合(即,在训练数据中拟合“噪声”,而不是信号)。
早期停止尝试删除手动设置该值的需要。它也可以被认为是一种正则化方法(如L1/L2权重衰减和丢弃),因为它可以阻止网络过拟合。

早停的思想相对简单:

  • 将数据分割成训练集和测试集
  • 在每个epoch的末尾(或每N个epoch):
    • 评估测试集上的网络性能
    • 如果网络性能超过以前的最佳模型:在当前epoch保存网络的副本
  • 作为最终模型,具有最佳测试集性能的模型

下面用图形显示:

Early Stopping

最好的模型是在垂直虚线时保存的模型,即在测试集上具有最高准确率的模型。
使用DL4J的早停功能需要你提供一些配置选项:

  • 得分计算器,如多层网络的DataSetLossCalculator(JavaDocSource Code)或计算图的DataSetLossCalculatorCG (JavaDocSource Code)。用于在每个epoch进行计算(例如:测试集上的损失函数值或测试集上的准确率)
  • 我们想要计算分数函数的频率(默认值:每个epoch)
  • 一个或多个终止条件,它告诉训练过程何时停止。停止条件有两类:
    • Epoch终止条件:每N个epoch评估
    • 迭代终止条件: 每个小批量评估一次 
  • 一个模型保存器,它定义了如何保存模型。

例如,在epoch终止条件最大为30epoch、最大为20分钟的训练时间的情况下,计算每个epoch的得分,并将中间结果保存到磁盘:


MultiLayerConfiguration myNetworkConfiguration = ...;
DataSetIterator myTrainData = ...;
DataSetIterator myTestData = ...;

EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
		.epochTerminationConditions(new MaxEpochsTerminationCondition(30))
		.iterationTerminationConditions(new MaxTimeIterationTerminationCondition(20, TimeUnit.MINUTES))
		.scoreCalculator(new DataSetLossCalculator(myTestData, true))
        .evaluateEveryNEpochs(1)
		.modelSaver(new LocalFileModelSaver(directory))
		.build();

EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf,myNetworkConfiguration,myTrainData);

//执行早停训练
EarlyStoppingResult result = trainer.fit();

//打印结果
System.out.println("Termination reason: " + result.getTerminationReason());
System.out.println("Termination details: " + result.getTerminationDetails());
System.out.println("Total epochs: " + result.getTotalEpochs());
System.out.println("Best epoch number: " + result.getBestModelEpoch());
System.out.println("Score at best epoch: " + result.getBestModelScore());

//得到最佳模型:
MultiLayerNetwork bestModel = result.getBestModel();

你还可以实现自己的迭代和epoch终止条件。

早停并行包装器

上面描述的早停实现将仅用于单个设备。然而,EarlyStoppingParallelTrainer提供与早期停止类似的功能,并允许你为多个CPU或GPU进行优化。EarlyStoppingParallelTrainer将你的模型包装在ParallelWrapper类中,并执行本地化的分布式训练。

请注意,EarlyStoppingParallelTrainer并不支持作为其单个设备使用时所具有的所有功能。它不是UI兼容的,可能无法与复杂的迭代监听器一起工作。这是由于模型是在后台分发和复制的机置引起的。

 

API


AutoencoderScoreCalculator 自编码器得分计算器

[源码]

一个多层网络或单个计算图的得分函数


ClassificationScoreCalculator 分类得分计算器

[源码]

作为多层网络和计算图的准确率、F1得分等。


DataSetLossCalculator 数据集损失计算器

[源码]

计算给定数据集(通常是测试集)上的得分(损失函数值)


DataSetLossCalculatorCG

[源码]

给定一个 DataSetIterator: 在该数据集上计算模型的总损失。通常用于计算测试集上的损失。


ROCScoreCalculator

[源码]

计算多层网络或计算图的ROC AUC(ROC曲线下的面积)或AUCPR(精确召回曲线下的区域)


RegressionScoreCalculator

[源码]

在测试集上计算网络(MultiLayerNetwork或ComputationGraph)的回归分数


VAEReconErrorScoreCalculator

[源码]

多层网络或计算图的变分自编码器重建误差的得分函数。VariationalAutoencoder 层必须是网络中的第一层。

VAEReconErrorScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator)

public VAEReconErrorScoreCalculator(RegressionEvaluation.Metric metric, DataSetIterator iterator)

重建误差的构造函数

  • 参数 metric
  • 参数 iterator

VAEReconProbScoreCalculator

[源码]

用于多层网络或计算图的变分自编码器重建概率或重建对数概率的得分计算器。VariationalAutoencoder 层必须是网络中的第一层。

VAEReconProbScoreCalculator(DataSetIterator iterator, int reconstructionProbNumSamples, boolean logProb) 

public VAEReconProbScoreCalculator(DataSetIterator iterator, int reconstructionProbNumSamples, boolean logProb) 

平均重建概率的构造函数

  • 参数 iterator 迭代器
  • 参数 reconstructionProbNumSamples 示例数量。 详见 {- link VariationalAutoencoder#reconstructionProbability(INDArray, int)} 
  • 参数 logProb 如果是 true: 计算 (负)对数概率。 false: 概率

BestScoreEpochTerminationCondition

[源码]

由 Sadat Anwar 创建于 3/26/16.

一旦达到预期分数,就停止训练。通常,如果当前分数低于初始化分数,则停止。如果希望一旦分数增加就停止训练,则将定义的分数设置lesserBetter设为false(请随意给该标志起一个更好的名称)

initialize

public void initialize() 
  • 弃用 “lessBetter” 参数不再使用

ScoreImprovementEpochTerminationCondition

[源码]

终止训练,如果最好的模型得分没有在N个epoch内得到改善

 

 

 

### 使用NDArray的原因 NDArray 是一种多维数组的数据结构,广泛应用于机器学习和深度学习领域。它提供了高效的内存管理和丰富的操作方法来处理大规模数据集。相比于传统的 Java 数据结构(如 `List<List<Float>>`),NDArray 提供了更简洁、高效的方式来进行张量运算[^1]。 在 Java 中实现类似的深度学习功能时,如果手动构建 N 维数组并动态插入元素,会显著增加开发复杂度和运行开销。而 NDArray 的设计初衷正是为了简化这一过程,并提供化后的性能支持。 --- ### 关于 DL4J 和 Java 的相关内容 #### 什么是 DL4J? DL4J(DeepLearning4J)是一个基于 JVM 的开源深度学习框架,专为分布式环境下的高性能计算而设计。它的目标是让开发者能够轻松地将深度学习技术集成到现有的 Java 或 Scala 应用程序中[^4]。 #### 官方文档教程资源 官方文档是最权威的学习资料之一,涵盖了从基础概念到高级应用的全面指导。以下是几个重要的参考资料链接: - **DL4J 官网**: https://deeplearning4j.org/ - **GitHub 存储库**: https://github.com/eclipse/deeplearning4j - **API 文档**: http://nd4j.org/api/ 通过阅读这些文档,可以快速掌握如何安装依赖项以及配置项目所需的 TensorFlow 原生库文件。 #### 示例代码:创建简单的神经网络模型 下面展示了一个基本的例子,演示如何利用 DL4J 构建一个多层感知器 (MLP) 来解决分类问题: ```java import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; public class SimpleNeuralNetwork { public static void main(String[] args) { int numInputs = 784; // 输入维度 int numOutputs = 10; // 输出类别数 NeuralNetConfiguration.ListBuilder listBuilder = new NeuralNetConfiguration.Builder() .seed(123L) .updater(new Sgd(0.1)) .list(); listBuilder.layer( new DenseLayer.Builder().nIn(numInputs).nOut(500) .activation(Activation.RELU) .build()); listBuilder.layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .activation(Activation.SOFTMAX) .nIn(500).nOut(numOutputs).build()); DataSet dataSet = new DataSet(Nd4j.rand(10, numInputs), Nd4j.eye(10)); System.out.println(dataSet); } } ``` 此代码片段定义了一种两层前馈神经网络架构——隐藏层采用 ReLU 激活函数;输出层则使用 Softmax 函数配合交叉熵损失进行训练。 --- ### §相关问题§ 1. 如何在 Maven 项目中引入 DL4J 及其相关依赖? 2. 利用 ND4J 进行矩阵乘法的具体实现方式是什么? 3. 在实际生产环境中部署 DL4J 模型有哪些注意事项? 4. 是否存在其他适合 Java 开发者的深度学习框架作为替代方案? 5. 性能方面,怎样提升 DL4J 训练大型数据集的速度?
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值