Deep Java Library (DJL) 提供了一套丰富的 API,使得 Java 开发者可以轻松地进行深度学习任务,包括模型加载、训练、推理等。以下是 DJL 的主要模块和 API 的详细介绍。
1. 核心模块
1.1 Model
Model 类是 DJL 中的核心类之一,用于表示深度学习模型。它可以加载预训练模型或自定义模型,并提供训练和推理的方法。
import ai.djl.Model;
// 创建一个空模型
Model model = Model.newInstance("myModel");
// 加载预训练模型
ZooModel<Image, Classifications> loadedModel = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optModelName("resnet18_v1")
.optEngine("PyTorch")
.optProgress(new ProgressBar())
.build()
.loadModel();
1.2 Predictor
Predictor 类用于执行模型的推理操作。通过 Predictor,可以将输入数据传递给模型并获取预测结果。
import ai.djl.inference.Predictor;
try (Predictor<Image, Classifications> predictor = loadedModel.newPredictor()) {
Image img = ImageFactory.getInstance().fromFile(Paths.get("path/to/your/image.jpg"));
Classifications result = predictor.predict(img);
System.out.println("预测结果: " + result);
}
2. 数据处理模块
2.1 NDArray
NDArray 类用于表示多维数组,类似于 NumPy 中的数组。它是进行张量操作的基础。
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
try (NDManager manager = NDManager.newBaseManager()) {
NDArray array = manager.create(new float[]{1.0f, 2.0f, 3.0f});
System.out.println(array);
}
2.2 Transform
Transform 接口用于定义数据预处理和后处理的步骤。常见的预处理步骤包括图像缩放、归一化等。
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.translate.Pipeline;
Pipeline pipeline = new Pipeline();
pipeline.add(new Resize(224, 224));
pipeline.add(new ToTensor());
3. 模型动物园
3.1 Criteria
Criteria 类用于定义模型的选择标准,包括模型类型、应用领域、框架等
import ai.djl.repository.zoo.Criteria;
Criteria<Image, Classifications> criteria = Criteria.builder()
.setTypes(Image.class, Classifications.class)
.optApplication(ai.djl.modality.cv.Application.IMAGE_CLASSIFICATION)
.optModelName("resnet18_v1")
.optEngine("PyTorch")
.optProgress(new ProgressBar())
.build();
3.2 ZooModel
ZooModel 类表示从模型动物园中加载的模型。通过 Criteria 可以加载预训练模型。
import ai.djl.repository.zoo.ZooModel;
try (ZooModel<Image, Classifications> model = criteria.loadModel()) {
// 使用模型进行推理
}
4. 训练模块
4.1 Trainer
Trainer 类用于训练模型。它管理模型的优化器、损失函数和评估指标。
import ai.djl.training.Trainer;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.loss.Loss;
import ai.djl.training.evaluator.Evaluator;
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.optOptimizer(Optimizer.adam())
.addEvaluator(new Accuracy())
.addTrainingListeners(TrainingListener.Defaults.logging());
try (Trainer trainer = model.newTrainer(config)) {
// 训练模型
}
4.2 Dataset
Dataset 类用于表示数据集。它可以加载和迭代数据,支持多种数据源。
import ai.djl.dataset.api.Dataset;
import ai.djl.dataset.cv.classification.ImageFolder;
Dataset dataset = ImageFolder.builder()
.setSampling(100, true)
.optUsage(Dataset.Usage.TRAIN)
.addTransform(pipeline)
.build();
dataset.prepare();
5. 其他模块
5.1 Translator
Translator 接口用于定义输入和输出的转换逻辑。它将原始数据转换为模型可以接受的格式,并将模型的输出转换为用户友好的格式。
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
class MyTranslator implements Translator<Image, Classifications> {
@Override
public Batch processInput(TranslatorContext ctx, Image input) {
// 将图像转换为 NDArray
return ctx.batchify(input.toNDList(ctx.getNDManager()));
}
@Override
public Classifications processOutput(TranslatorContext ctx, Batch output) {
// 将模型输出转换为分类结果
return Classifications.of(output.getOutputs().get(0));
}
}
总结
通过以上介绍,我们可以看到 DJL 提供了一套完整的 API,涵盖了从数据处理、模型加载、训练到推理的各个环节。这些 API 使得 Java 开发者可以更加方便地进行深度学习任务,而无需深入了解底层的复杂细节。希望这些介绍对你理解和使用 DJL 有所帮助