java(kotlin) ai框架djl

DJL(Deep Java Library)是一个开源的深度学习框架,由AWS推出,DJL支持多种深度学习后端,包括但不限于:

MXNet:由Apache软件基金会支持的开源深度学习框架。
PyTorch:广泛使用的开源机器学习库,由Facebook的AI研究团队开发。
TensorFlow:由Google开发的另一个流行的开源机器学习框架。
DJL与Java生态系统紧密集成,可以与Spring Boot、Quarkus等Java框架协同工作。

maven

 <!--        djl-->
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.28.0</version>
        </dependency>
  <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.28.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>0.28.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
            <version>0.28.0</version>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
            <version>0.28.0</version>
        </dependency>
        <!--        /djl-->

Java DJL 架构图

┌──────────────────────────────┐
│          ModelZoo            │
├──────────────────────────────┤
│            Model             │
└───────────────┬──────────────┘
                │
      ┌─────────▼─────────┐
      │       Engine      │
      └───────┬─┬─────────┘
              │ │
      ┌───────▼─▼─────────┐
      │     NDManager     │
      └───────┬─┬─────────┘
              │ │
    ┌─────────▼─▼───────────┐
    │    Dataset 
    └─────────┬─────────────┘
              │
    ┌─────────▼─────────────┐
    │  Trainer / Predictor  │
    └───────────────────────┘

主要组件详细描述

1. ModelZoo 和 Model
  • ModelZoo:提供多种预训练模型

    ModelZoo 的功能
    1. 模型发现与下载
      • ModelZoo 提供了一种机制,可以从多种来源(例如模型提供商、在线仓库等)发现和下载预训练模型。
      • 例如,可以从 AWS S3、Hugging Face、TensorFlow Hub 等平台下载模型。
    2. 模型加载
      • ModelZoo 提供了方便的方法来加载模型,用户可以根据需求加载不同类型的模型(例如图像分类模型、对象检测模型、自然语言处理模型等)。
      • 加载模型时,可以指定模型的名称、版本、以及模型的参数配置。
    3. 模型管理
      • ModelZoo 帮助用户管理已下载和加载的模型,可以方便地查看、更新和删除模型。
      • 通过这种方式,可以有效地管理本地的模型资源,避免重复下载和浪费存储空间。

    示例

    
    import ai.djl.Application
    import ai.djl.Model
    import ai.djl.ModelException
    import ai.djl.modality.Classifications
    import ai.djl.modality.cv.Image
    import ai.djl.repository.zoo.Criteria
    import ai.djl.repository.zoo.ModelZoo
    import ai.djl.translate.TranslateException
    
    
    object ModelZooExample {
         
        @Throws(ModelException::class, TranslateException::class)
        @JvmStatic
        fun main(args: Array<String>) {
         
            // 定义模型的标准
            val criteria: Criteria<Image, Classifications> = Criteria.builder()
                .optApplication(Application.CV.IMAGE_CLASSIFICATION) // 应用场景:图像分类
                .setTypes(Image::class.java, Classifications::class.java) // 输入输出类型
                .optFilter("backbone", "resnet50") // 模型过滤条件
                .build()
    
            // 从 ModelZoo 加载模型
            val model: Model = ModelZoo.loadModel(criteria)
    
            // 使用模型进行推理
            // ...
        }
    }
    
    
    

    ModelZoo 的类与接口

    • ModelZoo:核心类,提供模型的下载和加载功能。
    • Criteria:定义模型加载的标准和过滤条件,用于指定所需模型的应用场景、输入输出类型等。
    • ModelLoader:用于实际执行模型的下载和加载操作。
  • Model:表示一个深度学习模型的接口,包含模型的加载、保存和运行等操作。

    • ai.djl.ModelZoo

      Key Methods:
      • Model loadModel(Criteria<?, ?> criteria): Loads a model based on the provided criteria.
      • ModelInfo getModel(ModelId modelId): Retrieves information about a specific model using its ModelId.
      • Set<ModelId> listModels(ZooModel<?, ?> model): Lists all models in the zoo that match the given model.

      ai.djl.ModelInfo Interface

      ModelInfo provides metadata about a model, including its name, description, and input/output information.

      Key Methods:
      • String getName(): Returns the name of the model.
      • String getDescription(): Provides a description of the model.
      • Shape getInputShape(): Returns the shape of the input tensor.
      • Shape getOutputShape(): Returns the shape of the output tensor.

      ai.djl.ModelId Class

      ModelId uniquely identifies a model in the model zoo. It includes information about the model’s group, name, and version.

      Key Fields:
      • String getGroup(): Gets the group name of the model.
      • String getName(): Gets the name of the model.
      • String getVersion(): Gets the version of the model.

      ai.djl.Application Enum

      Application enumerates different types of applications supported by the model zoo, such as IMAGE_CLASSIFICATION, OBJECT_DETECTION, etc.

      Key Values:
      • CV.IMAGE_CLASSIFICATION
      • CV.OBJECT_DETECTION
      • NLP.TEXT_CLASSIFICATION

      ai.djl.Criteria Class

      Criteria is a builder for creating criteria objects used to filter and load models.

      Key Methods:
      • static Builder<?, ?> builder(): Creates a new builder instance.
      • Criteria<I, O> optApplication(Application application): Sets the application type.
      • Criteria<I, O> optEngine(String engine): Specifies the engine to use (e.g., MXNet, PyTorch)
      example
      
      import ai.djl.Model
      import ai.djl.ModelException
      import ai.djl.modality.Classifications
      import ai.djl.modality.cv.Image
      import ai.djl.modality.cv.ImageFactory
      import ai.djl.ndarray.NDList
      import ai.djl.translate.TranslateException
      import ai.djl.translate.Translator
      import ai.djl.translate.TranslatorContext
      import java.io.IOException
      import java.nio.file.Paths
      
      
      object DjlExample {
          @JvmStatic
          fun main(args: Array<String>) {
              // 模型路径
              val modelDir = Paths.get("models")
              val modelName = "resnet18"
              try {
                  Model.newInstance(modelName).use { model ->
                      // 加载模型
                      model.load(modelDir)
      
                      // 加载输入图像
                      val img =
                          ImageFactory.getInstance().fromFile(Paths.get(&
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值