java深度学习之DJL创建NiN网络

1、NiN块由一个卷积层和两个1×1卷积层组成。

构建NiN块代码如下


    /**
     * Builds one NiN block: a convolution with the caller-supplied geometry followed by
     * two 1x1 convolutions, each of the three convolutions being followed by a ReLU.
     *
     * @param numChannels  number of filters (output channels) used by all three convolutions
     * @param kernelShape  kernel size of the first convolution
     * @param strideShape  stride of the first convolution
     * @param paddingShape padding of the first convolution
     * @return a new {@link SequentialBlock} containing the six layers in order
     */
    public static SequentialBlock niNBlock(int numChannels, Shape kernelShape,
                                           Shape strideShape, Shape paddingShape) {
        SequentialBlock nin = new SequentialBlock();

        // Leading convolution: the only layer that uses the caller's kernel/stride/padding.
        nin.add(Conv2d.builder()
                .setKernelShape(kernelShape)
                .optStride(strideShape)
                .optPadding(paddingShape)
                .setFilters(numChannels)
                .build());
        nin.add(Activation::relu);

        // Two identical 1x1 convolution + ReLU stages.
        for (int stage = 0; stage < 2; stage++) {
            nin.add(Conv2d.builder()
                    .setKernelShape(new Shape(1, 1))
                    .setFilters(numChannels)
                    .build());
            nin.add(Activation::relu);
        }

        return nin;
    }

2、构建模型网络代码如下

 System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        // Assemble the NiN network one stage at a time.
        SequentialBlock block = new SequentialBlock();

        // Stage 1: 11x11 convolution, stride 4, no padding, then 3x3 max pooling with stride 2.
        block.add(niNBlock(96, new Shape(11, 11), new Shape(4, 4), new Shape(0, 0)));
        block.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)));

        // Stage 2: 5x5 convolution, stride 1, padding 2, then 3x3 max pooling with stride 2.
        block.add(niNBlock(256, new Shape(5, 5), new Shape(1, 1), new Shape(2, 2)));
        block.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)));

        // Stage 3: 3x3 convolution, stride 1, padding 1, then 3x3 max pooling with stride 2.
        block.add(niNBlock(384, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)));
        block.add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)));

        // Dropout layer with rate 0.5.
        block.add(Dropout.builder().optRate(0.5f).build());

        // Final NiN block: there are 10 label classes, so it emits 10 channels.
        block.add(niNBlock(10, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)));

        // The global average pooling layer automatically sets the window shape
        // to the height and width of the input.
        block.add(Pool.globalAvgPool2dBlock());

        // Transform the four-dimensional output into two-dimensional output
        // with a shape of (batch size, 10).
        block.add(Blocks.batchFlattenBlock());

        // Learning rate
        float lr = 0.1f;
        Model model = Model.newInstance("cnn");
        model.setBlock(block);

接下来就是准备数据并进行训练。

整体代码如下

package com.example.demo.djl;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.basicdataset.FashionMnist;
import ai.djl.basicdataset.ImageFolder;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.ArrayDataset;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.tracker.Tracker;
import ai.djl.translate.Pipeline;
import ai.djl.translate.TranslateException;
import com.example.demo.djl.covid19.Covid19Models;
import com.example.demo.djl.covid19.Covid19Training;
import org.apache.commons.lang3.ArrayUtils;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.StringColumn;
import tech.tablesaw.api.Table;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

/**
 * Builds a Network-in-Network (NiN) model with DJL, trains it on an image-folder dataset,
 * prints the final loss/accuracy figures, and saves the trained model and its labels.
 */
public class NiNTest {

    /**
     * Entry point: configures training, prints the output shape of every top-level layer,
     * trains the model, reports the final metrics and saves the result under {@code nin}.
     *
     * @param args unused
     * @throws IOException        propagated from dataset preparation, training, or saving
     * @throws ModelException     declared for model-related failures
     * @throws TranslateException propagated from dataset preparation or training
     */
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        // Directory the trained model is saved into
        Path modelDir = Paths.get("nin");
        // Learning rate
        float lr = 0.05f;

        Loss loss = Loss.softmaxCrossEntropyLoss();

        Tracker lrt = Tracker.fixed(lr);
        Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();

        DefaultTrainingConfig config = new DefaultTrainingConfig(loss).optOptimizer(sgd) // SGD optimizer
                .addEvaluator(new Accuracy()) // Model Accuracy
                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging

        try (NDManager manager = NDManager.newBaseManager(); Model model = getNiNModel(); Trainer trainer = model.newTrainer(config)) {
            Block block = model.getBlock();

            // Probe input drawn uniformly from [0, 1): one 3-channel 224x224 image.
            // It is only used to initialize the trainer and to trace layer output shapes.
            NDArray X = manager.randomUniform(0f, 1.0f, new Shape(1, 3, 224, 224));

            trainer.initialize(X.getShape());

            Shape currentShape = X.getShape();

            for (int i = 0; i < block.getChildren().size(); i++) {

                Shape[] newShape = block.getChildren().get(i).getValue().getOutputShapes(manager, new Shape[]{currentShape});
                currentShape = newShape[0];
                // Print the output shape of each top-level layer
                System.out.println(block.getChildren().get(i).getKey() + " layer output : " + currentShape);
            }
            // Mini-batch size
            int batchSize = 128;
            // Number of training epochs
            int numEpochs = 30;
            // Training loss per epoch
            double[] trainLoss;
            // Validation accuracy per epoch
            double[] testAccuracy;
            double[] epochCount;
            // Training accuracy per epoch
            double[] trainAccuracy;

            epochCount = new double[numEpochs];

            for (int i = 0; i < epochCount.length; i++) {
                epochCount[i] = i + 1;
            }
            // FashionMnist images are 28x28 single-channel grayscale. To train on it instead,
            // create the probe tensor X above with shape (1, 1, 224, 224) and enable the code below.
//        FashionMnist trainIter = FashionMnist.builder()
//                .optPipeline(new Pipeline().add(new Resize(224)).add(new ToTensor()))
//                .optUsage(Dataset.Usage.TRAIN)
//                .setSampling(batchSize, true)
//                .build();
//
//        FashionMnist testIter = FashionMnist.builder()
//                .optPipeline(new Pipeline().add(new Resize(224)).add(new ToTensor()))
//                .optUsage(Dataset.Usage.TEST)
//                .setSampling(batchSize, true)
//                .build();
//
//        trainIter.prepare();
//        testIter.prepare();
            // Use the COVID-19 x-ray images
            ImageFolder dataset = Covid19Training.initDataset("D:\\covid19dataset\\COVID-19 Radiography Database\\");
            // Split into training and validation data (8:2)
            RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

            RandomAccessDataset trainIter = datasets[0];

            RandomAccessDataset testIter = datasets[1];

            trainIter.prepare();
            testIter.prepare();


            Map<String, double[]> evaluatorMetrics = new HashMap<>();
            trainingChapter6(trainIter, testIter, numEpochs, trainer, evaluatorMetrics, 0);
            // Java passes primitives by value, so trainingChapter6 cannot hand the epoch time
            // back through its double parameter; read it from the trainer's metrics instead.
            double avgTrainTimePerEpoch = trainer.getMetrics().mean("epoch");

            trainLoss = evaluatorMetrics.get("train_epoch_SoftmaxCrossEntropyLoss");
            trainAccuracy = evaluatorMetrics.get("train_epoch_Accuracy");
            testAccuracy = evaluatorMetrics.get("validate_epoch_Accuracy");

            System.out.printf("loss %.3f,", trainLoss[numEpochs - 1]);
            System.out.printf(" train acc %.3f,", trainAccuracy[numEpochs - 1]);
            System.out.printf(" test acc %.3f\n", testAccuracy[numEpochs - 1]);
            // The division by 1e9 presumes the "epoch" metric is recorded in nanoseconds.
            System.out.printf("%.1f examples/sec", trainIter.size() / (avgTrainTimePerEpoch / Math.pow(10, 9)));
            System.out.println();

            // One label per metric value, laid out as [train loss | train acc | test acc]
            // to match the concatenation order of the "metrics" column below.
            String[] lossLabel = new String[trainLoss.length + testAccuracy.length + trainAccuracy.length];

            Arrays.fill(lossLabel, 0, trainLoss.length, "train loss");
            Arrays.fill(lossLabel, trainLoss.length, trainLoss.length + trainAccuracy.length, "train acc");
            Arrays.fill(lossLabel, trainLoss.length + trainAccuracy.length,
                    trainLoss.length + testAccuracy.length + trainAccuracy.length, "test acc");

            model.save(modelDir, "ninCovid19");

            // save labels into model directory
            Covid19Models.saveSynset(modelDir, dataset.getSynset());

            // Long-format table (epoch, metric value, series label) prepared for plotting.
            Table data = Table.create("Data").addColumns(
                    DoubleColumn.create("epoch", ArrayUtils.addAll(epochCount, ArrayUtils.addAll(epochCount, epochCount))),
                    DoubleColumn.create("metrics", ArrayUtils.addAll(trainLoss, ArrayUtils.addAll(trainAccuracy, testAccuracy))),
                    StringColumn.create("lossLabel", lossLabel)
            );
            // Plotting goes here

        }

    }

    /**
     * Builds one NiN block: a convolution with the given geometry followed by two 1x1
     * convolutions, each of the three convolutions being followed by a ReLU activation.
     *
     * @param numChannels  number of filters (output channels) used by all three convolutions
     * @param kernelShape  kernel size of the first convolution
     * @param strideShape  stride of the first convolution
     * @param paddingShape padding of the first convolution
     * @return a new {@link SequentialBlock} containing the six layers in order
     */
    public static SequentialBlock niNBlock(int numChannels, Shape kernelShape,
                                           Shape strideShape, Shape paddingShape) {

        SequentialBlock tempBlock = new SequentialBlock();
        tempBlock.add(Conv2d.builder()
                .setKernelShape(kernelShape)
                .optStride(strideShape)
                .optPadding(paddingShape)
                .setFilters(numChannels)
                .build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(1, 1))
                        .setFilters(numChannels)
                        .build())
                .add(Activation::relu)
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(1, 1))
                        .setFilters(numChannels)
                        .build())
                .add(Activation::relu);

        return tempBlock;
    }


    /**
     * Trains with {@link EasyTrain#fit} and copies the per-epoch training and validation
     * values of every evaluator of the trainer into {@code evaluatorMetrics}, keyed
     * {@code "train_epoch_<name>"} and {@code "validate_epoch_<name>"}.
     *
     * @param trainIter            training dataset
     * @param testIter             validation dataset
     * @param numEpochs            number of epochs to train
     * @param trainer              trainer to run; its metrics are replaced by a fresh {@link Metrics}
     * @param evaluatorMetrics     output map that receives the per-epoch metric arrays
     * @param avgTrainTimePerEpoch ignored; a primitive argument cannot carry a result back to the
     *                             caller, so callers should read {@code trainer.getMetrics().mean("epoch")}
     *                             after this method returns. Kept for signature compatibility.
     * @throws IOException        propagated from training
     * @throws TranslateException propagated from training
     */
    public static void trainingChapter6(RandomAccessDataset trainIter, RandomAccessDataset testIter,
                                        int numEpochs, Trainer trainer, Map<String, double[]> evaluatorMetrics, double avgTrainTimePerEpoch) throws IOException, TranslateException {

        trainer.setMetrics(new Metrics());

        EasyTrain.fit(trainer, numEpochs, trainIter, testIter);

        Metrics metrics = trainer.getMetrics();

        trainer.getEvaluators().forEach(evaluator -> {
            String trainKey = "train_epoch_" + evaluator.getName();
            String validateKey = "validate_epoch_" + evaluator.getName();
            evaluatorMetrics.put(trainKey, metricValues(metrics, trainKey));
            evaluatorMetrics.put(validateKey, metricValues(metrics, validateKey));
        });
    }

    /** Returns every recorded value of the named metric, in recording order, as doubles. */
    private static double[] metricValues(Metrics metrics, String metricName) {
        return metrics.getMetric(metricName).stream()
                .mapToDouble(x -> x.getValue().doubleValue())
                .toArray();
    }

    /**
     * Creates the NiN model: three NiN blocks each followed by 3x3/stride-2 max pooling,
     * dropout, a final 10-channel NiN block, global average pooling and a batch flatten.
     *
     * @return a new model named {@code "ninCovid19"} with the NiN block set
     */
    public static Model getNiNModel() {
        SequentialBlock block = new SequentialBlock();
        // Build the NiN network
        block.add(niNBlock(96, new Shape(11, 11), new Shape(4, 4), new Shape(0, 0)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                .add(niNBlock(256, new Shape(5, 5), new Shape(1, 1), new Shape(2, 2)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                .add(niNBlock(384, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                .add(Pool.maxPool2dBlock(new Shape(3, 3), new Shape(2, 2)))
                // Dropout layer
                .add(Dropout.builder().optRate(0.5f).build())
                // There are 10 label classes
                // NOTE(review): confirm this matches the number of classes in the dataset used by main
                .add(niNBlock(10, new Shape(3, 3), new Shape(1, 1), new Shape(1, 1)))
                // The global average pooling layer automatically sets the window shape
                // to the height and width of the input
                .add(Pool.globalAvgPool2dBlock())
                // Transform the four-dimensional output into two-dimensional output
                // with a shape of (batch size, 10)
                .add(Blocks.batchFlattenBlock());

        Model model = Model.newInstance("ninCovid19");
        model.setBlock(block);
        return model;
    }
}

准备训练数据的代码见文章《JAVA深度学习框架DJL之COVID19 x-ray图片分类》。

还需要在 pom 文件中添加以下依赖:

       <dependency>
            <groupId>tech.tablesaw</groupId>
            <artifactId>tablesaw-jsplot</artifactId>
            <version>0.30.4</version>
        </dependency>

### 如何在 Spring Boot 中集成深度学习算法

#### 使用 DJL 和 Spring Boot 构建深度学习应用

通过结合 Deep Java Library (DJL) 和 Spring Boot,开发者能够迅速创建具备深度学习能力的应用程序。DJL 是由 AWS 开源的项目,专为 Java 生态系统设计,支持多种主流深度学习引擎如 PyTorch, MXNet 等。

为了启动基于这两者的服务端点,在 `pom.xml` 文件里加入必要的依赖项:

```xml
<dependency>
    <groupId>ai.djl.spring.boot.starter</groupId>
    <artifactId>djl-spring-boot-starter</artifactId>
    <version>0.7.0</version>
</dependency>
```

定义 RESTful API 来接收输入并返回预测结果:

```java
@RestController
@RequestMapping("/api/djl")
public class DjlController {

    @PostMapping("/predict")
    public ResponseEntity<String> predict(@RequestBody Map<String,Object> payload){
        try {
            // 加载预训练模型
            Criteria<Image, Classifications> criteria = Criteria.builder()
                    .setTypes(Image.class, Classifications.class)
                    .optModelUrls("https://example.com/model.zip") // 替换成实际路径
                    .build();

            try (ZooModel<Image, Classifications> model = ModelZoo.loadModel(criteria)) {
                Translator<Image, Classifications> translator = model.getTranslator();
                Image img = ImageFactory.getInstance().fromUrl((String)payload.get("image_url"));
                Classifications classifications = model.predict(img);
                return new ResponseEntity<>(classifications.toString(), HttpStatus.OK);
            }
        } catch(Exception e){
            logger.error(e.getMessage());
            return new ResponseEntity<>("Error occurred",HttpStatus.INTERNAL_SERVER_ERROR);
        }
    }
}
```

此段代码实现了图像分类的功能,其中加载了一个远程存储的预训练模型,并利用传入 URL 获取待识别图片完成推断操作[^2]。

#### 利用 Deeplearning4j 进行分布式训练

对于更大规模的数据集或是复杂度更高的任务,则可以选择 Deeplearning4j(DL4J),它不仅提供了丰富的API接口用于构建各种类型的神经网络结构,还特别适合企业级应用场景下的高性能计算需求。DL4J 支持 Spark 集成从而允许跨节点执行大规模矩阵运算加速模型收敛速度。

配置 Maven 工程引入 DL4J 库文件:

```xml
<!-- https://mvnrepository.com/artifact/org.deeplearning4j/deeplearning4j-core -->
<dependency>
    <groupId>org.deeplearning4j</groupId>
    <artifactId>deeplearning4j-core</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
<!-- 如果计划使用Spark进行分布式的训练还需要额外添加spark相关依赖 -->
<dependency>
    <groupId>org.datavec</groupId>
    <artifactId>datavec-spark_2.11</artifactId>
    <version>1.0.0-beta7</version>
</dependency>
```
编写简单的多层感知器(MLP)来解决二元分类问题的例子如下所示:

```java
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .updater(new Adam())
        .list()
        .layer(new DenseLayer.Builder().nIn(numInputs).nOut(hiddenNodes)
                .activation(Activation.RELU)
                .weightInit(WeightInit.XAVIER)
                .build())
        .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .activation(Activation.SOFTMAX)
                .nIn(hiddenNodes).nOut(outputNum).build())
        .build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

DataSetIterator iterator = ... ;// 初始化数据集迭代器
for(int i=0;i<numEpochs;i++){
    while(iterator.hasNext()){
        DataSet next = iterator.next();
        net.fit(next);
    }
}

System.out.println(net.evaluate(testData));
```

这段脚本描述了怎样设置一个含单个隐藏层的人工神经网络来进行监督式的学习过程[^3]。

#### TensorFlow Serving with Spring Boot

另一种流行的方式是在 Spring Boot 上部署 Tensorflow 模型作为 gRPC 或 HTTP 推理服务器的一部分。这种方式的优势在于可以直接调用已经训练好的 TF SavedModels 而无需重新编码整个流程;同时得益于官方提供的客户端 SDK,使得与其他微服务之间的交互变得异常简单快捷。

首先确保安装好 tensorflow-serving-api 并将其余所需 jar 添加到项目的类路径下:

```bash
pip install tensorflow-serving-api==${TF_VERSION} # Python环境内运行该命令获取对应版本号
```

接着参照官方文档说明调整 application.properties 设置参数指向本地或云端托管的服务地址:

```properties
tensorflow.serving.host=localhost
tensorflow.serving.port=8500
```

最后一步就是封装请求体发送给目标主机等待响应解析即可得到最终结论[^4]。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

非ban必选

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值