【DJL】Springboot+Maven+DJL实现java调取pytorch模型

一、前言

近期学习了DJL(深度java学习),有了一点小的研究成果,特以此博客分享给大家。这个技术是一个特别新的技术,是亚马逊云服务在2019年re:Invent大会推出的专为Java开发者量身定制的深度学习框架,网上的资料比较少,只有官方文档可以参考,研究起来难度比较大,但是经过不懈的努力,终于搞定了,接下来以官网的demo入门。由于这块有很多坑,所以有必要好好的说一下。

官网地址:https://docs.djl.ai/jupyter/load_pytorch_model.html

二、demo
1、创建SpringBoot项目,导入pom依赖
    <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>0.6.0</version>
        </dependency>

        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>0.6.0</version>
            <scope>runtime</scope>
        </dependency>


        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.5.0</version>
            <scope>runtime</scope>
        </dependency>

        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.7.26</version>
        </dependency>

        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-simple</artifactId>
            <version>1.7.26</version>
        </dependency>

        <dependency>
            <groupId>net.java.dev.jna</groupId>
            <artifactId>jna</artifactId>
            <version>5.3.0</version>
        </dependency>

        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-cpu</artifactId>
            <classifier>win-x86_64</classifier>
            <scope>runtime</scope>
            <version>1.5.0</version>
        </dependency>
 </dependencies>
2、下载模型
    DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz", "build/pytorch_models/resnet18/resnet18.pt", new ProgressBar());
        DownloadUtils.download("https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt", "build/pytorch_models/resnet18/synset.txt", new ProgressBar());
       

下载完成后生成一个build文件夹,里面有你的模型(如下下载失败,请翻墙,连接外网)
在这里插入图片描述

3、创建一个Translator
 Pipeline pipeline = new Pipeline();
        pipeline.add(new Resize(256))
                .add(new CenterCrop(224, 224))
                .add(new ToTensor())
                .add(new Normalize(
                        new float[]{0.485f, 0.456f, 0.406f},
                        new float[]{0.229f, 0.224f, 0.225f}));

        Translator<Image, Classifications> translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .optApplySoftmax(true)
                .build();
4、加载你的模型resnet18
   System.setProperty("ai.djl.repository.zoo.location", "build/pytorch_models/resnet18");
    Criteria<Image, Classifications> criteria = Criteria.builder()
            .setTypes(Image.class, Classifications.class)
            // only search the model in local directory
            // "ai.djl.localmodelzoo:{name of the model}"
            .optArtifactId("ai.djl.localmodelzoo:resnet18")
            .optTranslator(translator)
            .optProgress(new ProgressBar()).build();
            ZooModel model = ModelZoo.loadModel(criteria);
5、使用图片进行预测
		// 自己本地
	    File fs=new File("D:\\testdjl\\dog.jpg");
        Image img = ImageFactory.getInstance().fromInputStream(new FileInputStream(fs));
        Predictor<Image, Classifications> predictor = model.newPredictor();
        Classifications classifications = predictor.predict(img);
        System.out.println(classifications);
6、执行结果

在这里插入图片描述

三、在运行的时候可能会报如下的错
1、No deep learning engine found

官网给出地址如下:
https://github.com/awslabs/djl/blob/master/docs/development/troubleshooting.md

我是 通过解决了nsatisfiedLinkError问题,解决的No deep learning engine found错误,官网有一个提示:CN:如果您在中国,可以使用DirectX修复工具来安装遗失依赖项。所以我就试着通过DirectX这个修复工具进行修复。
下载地址:https://www.onlinedown.net/soft/120082.htm,下载完之后安装就行了,如下:
在这里插入图片描述
2、路径中有中文(open file faild)在这里插入图片描述
将中文改成英文就好了

3、下载模型失败,记得翻墙

后记:

如果对你有所帮助,请记得点赞。

### 使用Java和Deep Java Library (DJL) 实现手写数字识别 #### 准备工作 为了使用DJL进行手写数字识别,首先需要设置开发环境并安装必要的依赖项。这可以通过Maven或Gradle来完成,在`pom.xml`文件中添加如下依赖: ```xml <dependency> <groupId>ai.djl</groupId> <artifactId>djl-deps</artifactId> <version>0.19.0</version> <type>pom</type> </dependency> <!-- 添加支持的后端引擎 --> <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-engine</artifactId> <version>0.19.0</version> </dependency> ``` 确保已正确配置了Java开发环境,并导入上述依赖。 #### 加载预训练模型 可以加载一个已经针对MNIST数据集训练好的模型来进行手写数字识别。这里展示如何下载并加载这样的模型: ```java import ai.djl.Model; import ai.djl.inference.Predictor; import ai.djl.modality.cv.ImageFactory; import ai.djl.repository.zoo.Criteria; import ai.djl.translate.TranslateException; public class HandwrittenDigitRecognition { private static final String MODEL_NAME = "mnist"; private static final Criteria<Image, Classifications> criteria = Criteria.builder() .setTypes(Image.class, Classifications.class) .optModelName(MODEL_NAME) .build(); public static void main(String[] args) throws TranslateException { try (Model model = Model.newInstance("mnist")) { model.load(new File("/path/to/model/directory")); // 创建预测器 try (Predictor<Image, Classifications> predictor = model.newPredictor(criteria)) { Image img = ImageFactory.getInstance().fromFile(new File("/path/to/image.png")); // 进行预测 Classifications classifications = predictor.predict(img); System.out.println(classifications.best()); } } catch (IOException e) { throw new RuntimeException(e); } } } ``` 这段代码展示了如何定义用于图像输入和分类输出的标准,以及怎样创建一个预测器对象执行具体的推理操作[^1]。 #### 数据准备与转换 当处理自定义的手写图片时,需将其调整至适合喂入神经网络的形式——通常是灰度化、缩放尺寸为28x28像素大小,并标准化数值范围在[0, 1]之间。这部分逻辑可以在调用predict方法之前加入到程序里。 #### 测试与验证 最后一步是对几个测试样本运行该应用程序,观察其能否准确无误地给出预期的结果。如果遇到问题,则可能需要进一步调试参数设定或是尝试其他更复杂的架构设计。
评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值