JAVA深度学习框架DJL之COVID19 x-ray图片分类

本文介绍了如何从 Kaggle 下载 COVID-19 肺部 X 光数据,使用 DJL 库训练 ResNetV1(50 层)模型进行肺部疾病分类,并演示了模型的预测过程。通过设置数据预处理、损失函数和训练参数,实现新冠肺炎(COVID-19)、其他类型肺炎与正常肺部的三分类。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

1、从 Kaggle 下载数据集 COVID-19 Radiography Database。

如原文截图所示,数据集包含三个文件夹:covid 文件夹下是新冠肺炎(COVID-19)感染者的肺部 X 光图片,分辨率为 256×256;normal 是正常人的肺部 X 光图片,分辨率为 1024×1024;pneumonia 是其他类型肺炎的肺部 X 光图片,分辨率为 1024×1024。训练和预测时,所有图片都会被统一缩放为 224×224。

2、训练类

/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.example.demo.djl.covid19;

import ai.djl.Model;
import ai.djl.basicdataset.ImageFolder;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.*;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Locale;

/**
 * Trains the COVID-19 chest X-ray classifier built by {@link Covid19Models#getModel()}.
 *
 * <p>In training, multiple passes (epochs) are made over the training data trying to find patterns
 * and trends in the data, which are then stored in the model. During the process the model is
 * evaluated for accuracy using the held-out validation data, and it is updated with the findings
 * of each epoch.
 */
public final class Covid19Training {

    /** Number of training samples processed before the model is updated. */
    private static final int BATCH_SIZE = 32;

    /** Number of passes over the complete training dataset. */
    private static final int EPOCHS = 7;

    /**
     * Trains the model on the local COVID-19 Radiography Database and writes the trained
     * parameters plus the label list ({@code synset.txt}) into the {@code covidmodels} directory.
     *
     * @param args unused
     * @throws IOException if the dataset cannot be read or the model cannot be saved
     * @throws TranslateException if preparing the dataset or a training batch fails
     */
    public static void main(String[] args) throws IOException, TranslateException {
        // Keep DJL's download cache on drive D: instead of the default location.
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");
        // Directory the trained parameters and synset.txt are written to.
        Path modelDir = Paths.get("covidmodels");

        // Load the labelled images from disk.
        ImageFolder dataset = initDataset("D:\\covid19dataset\\COVID-19 Radiography Database\\");
        // Randomly split the images 8:2 into a training set and a validation set.
        RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

        // Softmax cross-entropy compares the model's predictions with the correct label during
        // training; higher values mean more errors, so training tries to minimize it.
        Loss loss = Loss.softmaxCrossEntropyLoss();

        // Training parameters (hyperparameters): loss, accuracy evaluator, logging listeners.
        TrainingConfig config = setupTrainingConfig(loss);

        try (Model model = Covid19Models.getModel(); // untrained network to hold the patterns
             Trainer trainer = model.newTrainer(config)) {
            // Metrics collect and report key performance indicators, like accuracy.
            trainer.setMetrics(new Metrics());

            // A batch of one 3-channel image laid out as (N, C, H, W): height first, then width.
            Shape inputShape =
                    new Shape(1, 3, Covid19Models.IMAGE_HEIGHT, Covid19Models.IMAGE_WIDTH);

            // Initialize the trainer (and the network's parameters) for that input shape.
            trainer.initialize(inputShape);

            // Train on datasets[0] and validate on datasets[1].
            EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);

            // Record the final validation metrics as model properties. Locale.ROOT keeps the
            // decimal separator a '.' whatever the JVM's default locale is.
            TrainingResult result = trainer.getTrainingResult();
            model.setProperty("Epoch", String.valueOf(EPOCHS));
            model.setProperty(
                    "Accuracy",
                    String.format(Locale.ROOT, "%.5f", result.getValidateEvaluation("Accuracy")));
            model.setProperty(
                    "Loss", String.format(Locale.ROOT, "%.5f", result.getValidateLoss()));

            // Save the trained parameters for later inference. The file name is built from
            // MODEL_NAME and the "Epoch" property, e.g. covidclassifier-0007.params
            // (NOTE(review): naming scheme per DJL's Model.save; confirm for the DJL version used).
            model.save(modelDir, Covid19Models.MODEL_NAME);

            // Save the labels next to the parameters so predictions can be mapped to names.
            Covid19Models.saveSynset(modelDir, dataset.getSynset());
        }
    }

    /**
     * Builds and prepares the image dataset rooted at {@code datasetRoot}.
     *
     * <p>Every image is resized to {@link Covid19Models#IMAGE_WIDTH} x
     * {@link Covid19Models#IMAGE_HEIGHT}, converted to a tensor, and served in randomly sampled
     * batches of {@link #BATCH_SIZE}.
     *
     * @param datasetRoot directory containing one sub-folder of images per class
     * @return the prepared dataset
     * @throws IOException if the image folders cannot be read
     * @throws TranslateException if preparing the dataset fails
     */
    private static ImageFolder initDataset(String datasetRoot)
            throws IOException, TranslateException {
        ImageFolder dataset =
                ImageFolder.builder()
                        // where the images are read from
                        .setRepositoryPath(Paths.get(datasetRoot))
                        // maximum directory depth searched for images
                        .optMaxDepth(10)
                        .addTransform(new Resize(Covid19Models.IMAGE_WIDTH, Covid19Models.IMAGE_HEIGHT))
                        .addTransform(new ToTensor())
                        // random sampling; don't process the data in order
                        .setSampling(BATCH_SIZE, true)
                        .build();

        dataset.prepare();
        return dataset;
    }

    /**
     * Creates the training configuration.
     *
     * @param loss the loss function to minimize
     * @return a configuration that also tracks accuracy and logs training progress
     */
    private static TrainingConfig setupTrainingConfig(Loss loss) {
        return new DefaultTrainingConfig(loss)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }
}

3、构建 CNN 网络类(基于 50 层的 ResNetV1)

/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.example.demo.djl.covid19;

import ai.djl.Model;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.convolutional.Conv3d;
import ai.djl.nn.core.Linear;

import java.io.IOException;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

/**
 * Helper that builds the network used by the COVID-19 X-ray classifier and saves its labels.
 */
public final class Covid19Models {

    /** Number of classification labels: covid, normal and pneumonia. */
    public static final int NUM_OF_OUTPUT = 3;

    /** Height in pixels that images are resized to before entering the network. */
    public static final int IMAGE_HEIGHT = 224;

    /** Width in pixels that images are resized to before entering the network. */
    public static final int IMAGE_WIDTH = 224;

    /** Name of the model, used when saving and loading its parameters. */
    public static final String MODEL_NAME = "covidclassifier";

    private Covid19Models() {
    }

    /**
     * Creates a new model whose block is a 50-layer ResNet V1 for 3-channel
     * {@value #IMAGE_HEIGHT}x{@value #IMAGE_WIDTH} input images with {@value #NUM_OF_OUTPUT}
     * outputs.
     *
     * <p>The returned model has no trained parameters. Callers either train it or load saved
     * parameters into it, and they own it and must close it.
     *
     * @return a new, untrained model instance
     */
    public static Model getModel() {
        Model model = Model.newInstance(MODEL_NAME);

        // Input shape is (channels, height, width); the batch dimension is not included.
        Block resNet50 =
                ResNetV1.builder()
                        .setImageShape(new Shape(3, IMAGE_HEIGHT, IMAGE_WIDTH))
                        .setNumLayers(50)
                        .setOutSize(NUM_OF_OUTPUT)
                        .build();

        model.setBlock(resNet50);
        return model;
    }

    /**
     * Writes the class labels, one per line, to {@code synset.txt} inside {@code modelDir}.
     *
     * <p>The directory is created if it does not exist yet. This means the method works even
     * when it is called before the model parameters have been saved there.
     *
     * @param modelDir directory that holds (or will hold) the saved model
     * @param synset class labels in the order of the network's outputs
     * @throws IOException if the directory cannot be created or the file cannot be written
     */
    public static void saveSynset(Path modelDir, List<String> synset) throws IOException {
        Files.createDirectories(modelDir);
        Path synsetFile = modelDir.resolve("synset.txt");
        // Files.newBufferedWriter(Path) encodes as UTF-8.
        try (Writer writer = Files.newBufferedWriter(synsetFile)) {
            writer.write(String.join("\n", synset));
        }
    }
}

4、使用生成的模型预测

/*
 * Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
 * with the License. A copy of the License is located at
 *
 * http://aws.amazon.com/apache2.0/
 *
 * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
 * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
 * and limitations under the License.
 */
package com.example.demo.djl.covid19;

import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;

/**
 * Command-line entry point that classifies a single chest X-ray image with the trained COVID-19
 * model and prints the resulting per-label probabilities.
 */
public class Covid19Inference {

    /**
     * Loads the model saved in {@code covidmodels} and classifies one image.
     *
     * <p>Besides the default COVID sample, other images from the dataset that can be passed as
     * {@code args[0]} include {@code normal/NORMAL (137).png} and
     * {@code pneumonia/Viral Pneumonia (96).png} under the COVID-19 Radiography Database root.
     *
     * @param args optional; {@code args[0]} is the path of the image to classify
     * @throws ModelException if the saved model cannot be loaded
     * @throws TranslateException if pre- or post-processing of the image fails
     * @throws IOException if the image or the model files cannot be read
     */
    public static void main(String[] args) throws ModelException, TranslateException, IOException {
        Path savedModelDir = Paths.get("covidmodels");
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");

        // Use the path given on the command line, or fall back to a sample COVID image.
        String target = args.length > 0
                ? args[0]
                : "D:\\covid19dataset\\COVID-19 Radiography Database\\covid\\COVID (118).png";
        Image xray = ImageFactory.getInstance().fromFile(Paths.get(target));

        // Rebuild the same network as in training, then restore the trained parameters into it.
        try (Model classifier = Covid19Models.getModel()) {
            classifier.load(savedModelDir, Covid19Models.MODEL_NAME);

            try (Predictor<Image, Classifications> predictor =
                         classifier.newPredictor(buildTranslator())) {
                Classifications probabilities = predictor.predict(xray);
                System.out.println(probabilities);
            }
        }
    }

    /**
     * Builds the pre- and post-processing pipeline. It resizes the image to the training
     * resolution, converts it to a tensor, and applies softmax to the network's outputs so they
     * read as probabilities.
     *
     * @return translator from an {@link Image} to its {@link Classifications}
     */
    private static Translator<Image, Classifications> buildTranslator() {
        return ImageClassificationTranslator.builder()
                .addTransform(new Resize(Covid19Models.IMAGE_WIDTH, Covid19Models.IMAGE_HEIGHT))
                .addTransform(new ToTensor())
                .optApplySoftmax(true)
                .build();
    }
}

免责声明:本示例仅用于学习,不可用于商业或实际医学场景。若在实际场景中使用,后果自负。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

非ban必选

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值