DJL(Deep Java Library)是一个开源的深度学习框架,由AWS推出,DJL支持多种深度学习后端,包括但不限于:
MXNet:由Apache软件基金会支持的开源深度学习框架。
PyTorch:广泛使用的开源机器学习库,由Facebook的AI研究团队开发。
TensorFlow:由Google开发的另一个流行的开源机器学习框架。
DJL与Java生态系统紧密集成,可以与Spring Boot、Quarkus等Java框架协同工作。
maven
<!-- djl-->
<dependency>
<groupId>ai.djl</groupId>
<artifactId>api</artifactId>
<version>0.28.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-engine</artifactId>
<version>0.28.0</version>
</dependency>
<dependency>
<groupId>ai.djl.pytorch</groupId>
<artifactId>pytorch-model-zoo</artifactId>
<version>0.28.0</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>basicdataset</artifactId>
<version>0.28.0</version>
</dependency>
<dependency>
<groupId>ai.djl</groupId>
<artifactId>model-zoo</artifactId>
<version>0.28.0</version>
</dependency>
<!-- /djl-->
Java DJL 架构图
┌──────────────────────────────┐
│ ModelZoo │
├──────────────────────────────┤
│ Model │
└───────────────┬──────────────┘
│
┌─────────▼─────────┐
│ Engine │
└───────┬─┬─────────┘
│ │
┌───────▼─▼─────────┐
│ NDManager │
└───────┬─┬─────────┘
│ │
┌─────────▼─▼───────────┐
│ Dataset
└─────────┬─────────────┘
│
┌─────────▼─────────────┐
│ Trainer / Predictor │
└───────────────────────┘
主要组件详细描述
1. ModelZoo 和 Model
-
ModelZoo:提供多种预训练模型
ModelZoo
的功能- 模型发现与下载:
ModelZoo
提供了一种机制,可以从多种来源(例如模型提供商、在线仓库等)发现和下载预训练模型。- 例如,可以从 AWS S3、Hugging Face、TensorFlow Hub 等平台下载模型。
- 模型加载:
ModelZoo
提供了方便的方法来加载模型,用户可以根据需求加载不同类型的模型(例如图像分类模型、对象检测模型、自然语言处理模型等)。- 加载模型时,可以指定模型的名称、版本、以及模型的参数配置。
- 模型管理:
ModelZoo
帮助用户管理已下载和加载的模型,可以方便地查看、更新和删除模型。- 通过这种方式,可以有效地管理本地的模型资源,避免重复下载和浪费存储空间。
示例
import ai.djl.Application import ai.djl.Model import ai.djl.ModelException import ai.djl.modality.Classifications import ai.djl.modality.cv.Image import ai.djl.repository.zoo.Criteria import ai.djl.repository.zoo.ModelZoo import ai.djl.translate.TranslateException object ModelZooExample { @Throws(ModelException::class, TranslateException::class) @JvmStatic fun main(args: Array<String>) { // 定义模型的标准 val criteria: Criteria<Image, Classifications> = Criteria.builder() .optApplication(Application.CV.IMAGE_CLASSIFICATION) // 应用场景:图像分类 .setTypes(Image::class.java, Classifications::class.java) // 输入输出类型 .optFilter("backbone", "resnet50") // 模型过滤条件 .build() // 从 ModelZoo 加载模型 val model: Model = ModelZoo.loadModel(criteria) // 使用模型进行推理 // ... } }
ModelZoo
的类与接口ModelZoo
:核心类,提供模型的下载和加载功能。Criteria
:定义模型加载的标准和过滤条件,用于指定所需模型的应用场景、输入输出类型等。ModelLoader
:用于实际执行模型的下载和加载操作。
- 模型发现与下载:
-
Model:表示一个深度学习模型的接口,包含模型的加载、保存和运行等操作。
-
ai.djl.ModelZoo
Key Methods:
Model loadModel(Criteria<?, ?> criteria)
: Loads a model based on the provided criteria.ModelInfo getModel(ModelId modelId)
: Retrieves information about a specific model using itsModelId
.Set<ModelId> listModels(ZooModel<?, ?> model)
: Lists all models in the zoo that match the given model.
ai.djl.ModelInfo
InterfaceModelInfo
provides metadata about a model, including its name, description, and input/output information.Key Methods:
String getName()
: Returns the name of the model.String getDescription()
: Provides a description of the model.Shape getInputShape()
: Returns the shape of the input tensor.Shape getOutputShape()
: Returns the shape of the output tensor.
ai.djl.ModelId
ClassModelId
uniquely identifies a model in the model zoo. It includes information about the model’s group, name, and version.Key Fields:
String getGroup()
: Gets the group name of the model.String getName()
: Gets the name of the model.String getVersion()
: Gets the version of the model.
ai.djl.Application
EnumApplication
enumerates different types of applications supported by the model zoo, such as IMAGE_CLASSIFICATION, OBJECT_DETECTION, etc.Key Values:
CV.IMAGE_CLASSIFICATION
CV.OBJECT_DETECTION
NLP.TEXT_CLASSIFICATION
ai.djl.Criteria
ClassCriteria
is a builder for creating criteria objects used to filter and load models.Key Methods:
static Builder<?, ?> builder()
: Creates a new builder instance.Criteria<I, O> optApplication(Application application)
: Sets the application type.Criteria<I, O> optEngine(String engine)
: Specifies the engine to use (e.g., MXNet, PyTorch)
example
import ai.djl.Model import ai.djl.ModelException import ai.djl.modality.Classifications import ai.djl.modality.cv.Image import ai.djl.modality.cv.ImageFactory import ai.djl.ndarray.NDList import ai.djl.translate.TranslateException import ai.djl.translate.Translator import ai.djl.translate.TranslatorContext import java.io.IOException import java.nio.file.Paths object DjlExample { @JvmStatic fun main(args: Array<String>) { // 模型路径 val modelDir = Paths.get("models") val modelName = "resnet18" try { Model.newInstance(modelName).use { model -> // 加载模型 model.load(modelDir) // 加载输入图像 val img = ImageFactory.getInstance().fromFile(Paths.get(&
-