1、从 Kaggle 下载数据。
如上图所示:covid 文件夹下是 COVID-19 感染者的肺部 X-ray 图片,图片像素为 256*256;normal 文件夹下是正常人的肺部 X-ray 图片,图片像素为 1024*1024;pneumonia 文件夹下是其他类型肺炎患者的肺部 X-ray 图片,图片像素为 1024*1024。
2、训练类
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package com.example.demo.djl.covid19;
import ai.djl.Model;
import ai.djl.basicdataset.ImageFolder;
import ai.djl.metric.Metrics;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.*;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
 * Trains the chest X-ray classifier defined by {@link Covid19Models}.
 *
 * <p>In training, multiple passes (epochs) are made over the training data; after each epoch the
 * model is evaluated against the validation split. When training finishes, the model and the
 * label list (synset) are written to the model directory so that they can be loaded for inference.
 */
public final class Covid19Training {

    /** Number of training samples processed before the model is updated. */
    private static final int BATCH_SIZE = 32;

    /** Number of passes over the complete training split. */
    private static final int EPOCHS = 7;

    /** Dataset root used when no path is given on the command line. */
    private static final String DEFAULT_DATASET_ROOT =
            "D:\\covid19dataset\\COVID-19 Radiography Database\\";

    /**
     * Runs the complete training pipeline: load the image folder, split it, train, and save.
     *
     * @param args optional; {@code args[0]} overrides the dataset root directory. When absent the
     *     built-in default location is used.
     * @throws IOException if the dataset cannot be read or the model/synset cannot be written
     * @throws TranslateException if a sample cannot be converted while preparing or training
     */
    public static void main(String[] args) throws IOException, TranslateException {
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");

        // directory the trained model and its synset are written to
        Path modelDir = Paths.get("covidmodels");

        // load the labelled images; each sub-folder of the root is one class
        String datasetRoot = args.length > 0 ? args[0] : DEFAULT_DATASET_ROOT;
        ImageFolder dataset = initDataset(datasetRoot);

        // split 8:2 -> datasets[0] is used for training, datasets[1] for validation
        RandomAccessDataset[] datasets = dataset.randomSplit(8, 2);

        // the loss function scores predictions against the correct answer during training;
        // higher values mean more errors, so training seeks to minimize it
        Loss loss = Loss.softmaxCrossEntropyLoss();

        // training parameters (i.e. hyperparameters)
        TrainingConfig config = setupTrainingConfig(loss);

        try (Model model = Covid19Models.getModel(); // untrained model instance
                Trainer trainer = model.newTrainer(config)) {
            // metrics collect and report key performance indicators, like accuracy
            trainer.setMetrics(new Metrics());

            // one 3-channel image, laid out as (batch, channel, height, width).
            // The last two dimensions must be height THEN width; using the height constant
            // twice only works by accident while both constants are equal.
            Shape inputShape =
                    new Shape(1, 3, Covid19Models.IMAGE_HEIGHT, Covid19Models.IMAGE_WIDTH);

            // initialize trainer with proper input shape
            trainer.initialize(inputShape);

            // find the patterns in data
            EasyTrain.fit(trainer, EPOCHS, datasets[0], datasets[1]);

            // record the outcome of the run as model properties
            TrainingResult result = trainer.getTrainingResult();
            model.setProperty("Epoch", String.valueOf(EPOCHS));
            model.setProperty(
                    "Accuracy", String.format("%.5f", result.getValidateEvaluation("Accuracy")));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));

            // save the trained model under Covid19Models.MODEL_NAME for later inference
            model.save(modelDir, Covid19Models.MODEL_NAME);

            // save labels into the model directory
            Covid19Models.saveSynset(modelDir, dataset.getSynset());
        }
    }

    /**
     * Builds and prepares the image-folder dataset rooted at {@code datasetRoot}.
     *
     * <p>Every image is resized to {@code IMAGE_WIDTH x IMAGE_HEIGHT} and converted to a tensor;
     * batches of {@link #BATCH_SIZE} samples are drawn in random order.
     *
     * @param datasetRoot directory that contains the image folders
     * @return the prepared dataset
     * @throws IOException if the folder cannot be scanned
     * @throws TranslateException if preparing the dataset fails
     */
    private static ImageFolder initDataset(String datasetRoot)
            throws IOException, TranslateException {
        ImageFolder dataset =
                ImageFolder.builder()
                        // retrieve the data
                        .setRepositoryPath(Paths.get(datasetRoot))
                        .optMaxDepth(10)
                        .addTransform(
                                new Resize(Covid19Models.IMAGE_WIDTH, Covid19Models.IMAGE_HEIGHT))
                        .addTransform(new ToTensor())
                        // random sampling; don't process the data in order
                        .setSampling(BATCH_SIZE, true)
                        .build();
        dataset.prepare();
        return dataset;
    }

    /**
     * Creates the training configuration: the given loss, an accuracy evaluator, and the default
     * logging listeners.
     *
     * @param loss loss function to minimize
     * @return the training configuration
     */
    private static TrainingConfig setupTrainingConfig(Loss loss) {
        return new DefaultTrainingConfig(loss)
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());
    }
}
3、构建cnn网络类
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package com.example.demo.djl.covid19;
import ai.djl.Model;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.convolutional.Conv3d;
import ai.djl.nn.core.Linear;
import java.io.IOException;
import java.io.Writer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
/**
 * A helper class that creates the untrained network and saves the label list (synset).
 *
 * <p>The constants declared here are shared by training and inference so that both use the same
 * input size, number of outputs, and model name.
 */
public final class Covid19Models {

    /** The number of classification labels: pneumonia, normal, covid. */
    public static final int NUM_OF_OUTPUT = 3;

    /** The image height used for pre-processing. */
    public static final int IMAGE_HEIGHT = 224;

    /** The image width used for pre-processing. */
    public static final int IMAGE_WIDTH = 224;

    /** The name of the model. */
    public static final String MODEL_NAME = "covidclassifier";

    private Covid19Models() {
    }

    /**
     * Creates a new model named {@link #MODEL_NAME} whose block is a 50-layer ResNetV1 with
     * {@link #NUM_OF_OUTPUT} outputs for 3-channel {@code IMAGE_HEIGHT x IMAGE_WIDTH} images.
     *
     * <p>The caller owns the returned model and is responsible for closing it.
     *
     * @return the newly created model with its network block set
     */
    public static Model getModel() {
        // create new instance of an empty model
        Model model = Model.newInstance(MODEL_NAME);

        // construct the network
        Block resNet50 =
                ResNetV1.builder()
                        .setImageShape(new Shape(3, IMAGE_HEIGHT, IMAGE_WIDTH))
                        .setNumLayers(50)
                        .setOutSize(NUM_OF_OUTPUT)
                        .build();

        // set the neural network to the model
        model.setBlock(resNet50);
        return model;
    }

    /**
     * Writes the class labels to {@code synset.txt} inside {@code modelDir}, one label per line.
     *
     * <p>The directory is created first when it does not exist yet, so this method can be called
     * independently of saving the model.
     *
     * @param modelDir directory that receives the {@code synset.txt} file
     * @param synset class labels in the order they should be written
     * @throws IOException if the directory cannot be created or the file cannot be written
     */
    public static void saveSynset(Path modelDir, List<String> synset) throws IOException {
        // without this, writing fails with NoSuchFileException for a not-yet-existing directory
        if (!Files.isDirectory(modelDir)) {
            Files.createDirectories(modelDir);
        }
        Path synsetFile = modelDir.resolve("synset.txt");
        try (Writer writer = Files.newBufferedWriter(synsetFile)) {
            writer.write(String.join("\n", synset));
        }
    }
}
4、使用生成的模型预测
/*
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package com.example.demo.djl.covid19;
import ai.djl.Model;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.ImageClassificationTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
 * Loads the trained model from disk and classifies a single chest X-ray image, printing the
 * resulting per-label probabilities to standard output.
 */
public class Covid19Inference {

    /**
     * Classifies one image with the saved model.
     *
     * @param args optional; {@code args[0]} is the path of the image to classify. When absent a
     *     built-in sample image path is used.
     * @throws ModelException if the saved model cannot be loaded
     * @throws TranslateException if pre- or post-processing of the prediction fails
     * @throws IOException if the image or the model files cannot be read
     */
    public static void main(String[] args) throws ModelException, TranslateException, IOException {
        // the location where the trained model was saved
        Path modelLocation = Paths.get("covidmodels");
        System.setProperty("DJL_CACHE_DIR", "d:/ai/djl");

        // image to classify: first command-line argument, otherwise the bundled COVID sample.
        // Other samples from the same dataset that can be tried instead:
        //   D:\covid19dataset\COVID-19 Radiography Database\normal\NORMAL (137).png
        //   D:\covid19dataset\COVID-19 Radiography Database\pneumonia\Viral Pneumonia (96).png
        String imageLocation =
                args.length != 0
                        ? args[0]
                        : "D:\\covid19dataset\\COVID-19 Radiography Database\\covid\\COVID (118).png";

        // read the image before the model is opened
        Image xray = ImageFactory.getInstance().fromFile(Paths.get(imageLocation));

        try (Model classifier = Covid19Models.getModel()) {
            // fill the empty network with the trained parameters
            classifier.load(modelLocation, Covid19Models.MODEL_NAME);

            // run the inference using a Predictor
            try (Predictor<Image, Classifications> predictor =
                    classifier.newPredictor(newTranslator())) {
                // holds the probability score per label
                Classifications scores = predictor.predict(xray);
                System.out.println(scores);
            }
        }
    }

    /**
     * Builds the translator for pre- and post-processing: the image is resized to the model's
     * input size and converted to a tensor, and softmax is applied to the network output.
     *
     * @return the image classification translator
     */
    private static Translator<Image, Classifications> newTranslator() {
        return ImageClassificationTranslator.builder()
                .addTransform(new Resize(Covid19Models.IMAGE_WIDTH, Covid19Models.IMAGE_HEIGHT))
                .addTransform(new ToTensor())
                .optApplySoftmax(true)
                .build();
    }
}
免责声明:本实例仅用于学习,不可用于商业或实际医学场景。若在实际场景中使用,后果自负。