ONNX Runtime for Java 实战:基于预训练模型实现图片分类

ONNX Runtime 是一款跨平台、高性能的推理引擎,能够高效运行 ONNX 格式的深度学习模型。Java 作为企业级开发的主流语言,结合 ONNX Runtime 可轻松实现 AI 模型的端到端推理。本文将以图片分类为核心场景,从零搭建一个基于 ONNX Runtime for Java 的完整项目,详细讲解环境配置、模型加载、数据预处理、推理执行及结果解析全流程。

一、实战背景与准备工作

1.1 核心目标

使用预训练的 ResNet-50 ONNX 模型,通过 Java + ONNX Runtime 实现对任意图片的分类推理,输出图片的类别标签及置信度。

1.2 环境依赖

  • JDK 8 及以上(推荐 11/17,兼容 ONNX Runtime 最新版本)
  • Maven 3.6+(项目构建工具)
  • ONNX Runtime Java 包(onnxruntime-gpu/onnxruntime,根据硬件选择)
  • 预训练 ResNet-50 ONNX 模型(可从 ONNX Model Zoo 下载)
  • 辅助依赖:OpenCV(图片预处理)、FastJSON(结果格式化)

1.3 模型与标签准备

  1. 下载 ResNet-50 ONNX 模型:从 ONNX Model Zoo 下载 resnet50-v1-12.onnx,保存至项目 src/main/resources/models 目录。
  2. 下载 ImageNet 标签文件:创建 src/main/resources/labels/imagenet_labels.txt,保存 1000 类 ImageNet 标签(可从 GitHub 转换为文本格式)。

二、项目搭建与依赖配置

2.1 创建 Maven 项目

通过 IDEA/Eclipse 创建 Maven 项目,核心 pom.xml 依赖配置如下:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.example</groupId>
    <artifactId>onnxruntime-java-image-classification</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
        <onnxruntime.version>1.17.0</onnxruntime.version>
        <opencv.version>4.8.0</opencv.version>
        <fastjson.version>2.0.43</fastjson.version>
    </properties>

    <dependencies>
        <!-- ONNX Runtime Java 核心依赖(CPU版),GPU版替换为 onnxruntime-gpu -->
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>${onnxruntime.version}</version>
        </dependency>

        <!-- OpenCV 用于图片预处理(读取、缩放、归一化) -->
        <dependency>
            <groupId>org.openpnp</groupId>
            <artifactId>opencv</artifactId>
            <version>${opencv.version}</version>
        </dependency>

        <!-- FastJSON 用于结果格式化输出 -->
        <dependency>
            <groupId>com.alibaba.fastjson2</groupId>
            <artifactId>fastjson2</artifactId>
            <version>${fastjson.version}</version>
        </dependency>

        <!-- 测试依赖 -->
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.13.2</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <!-- 打包时包含依赖 -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>3.4.1</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>

2.2 目录结构

最终项目目录如下:

onnxruntime-java-image-classification/
├── src/
│   ├── main/
│   │   ├── java/
│   │   │   └── com/
│   │   │       └── example/
│   │   │           ├── ImageClassifier.java  // 核心分类逻辑
│   │   │           ├── OnnxModelLoader.java  // 模型加载工具
│   │   │           └── utils/
│   │   │               ├── ImagePreprocessor.java  // 图片预处理
│   │   │               └── LabelLoader.java  // 标签加载工具
│   │   └── resources/
│   │       ├── models/
│   │       │   └── resnet50-v1-12.onnx  // ResNet-50 模型
│   │       └── labels/
│   │           └── imagenet_labels.txt  // ImageNet 标签
│   └── test/
│       └── java/
│           └── com/
│               └── example/
│                   └── ImageClassifierTest.java  // 测试类
└── pom.xml

三、核心功能实现

3.1 标签加载工具(LabelLoader.java)

读取 imagenet_labels.txt,将标签存储为 List,方便后续根据索引查询类别:

package com.example.utils;

import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;

public class LabelLoader {
    /**
     * 加载标签文件
     * @param labelPath 标签文件路径(resources 下)
     * @return 标签列表
     */
    public static List<String> loadLabels(String labelPath) {
        List<String> labels = new ArrayList<>();
        try (InputStream is = LabelLoader.class.getClassLoader().getResourceAsStream(labelPath);
             BufferedReader br = new BufferedReader(new InputStreamReader(is))) {
            String line;
            while ((line = br.readLine()) != null) {
                labels.add(line.trim());
            }
        } catch (Exception e) {
            throw new RuntimeException("加载标签文件失败", e);
        }
        return labels;
    }
}

3.2 图片预处理工具(ImagePreprocessor.java)

ResNet-50 要求输入为 (1, 3, 224, 224) 的张量,且需进行归一化(均值 [0.485, 0.456, 0.406],标准差 [0.229, 0.224, 0.225])。使用 OpenCV 实现图片读取、缩放、通道转换、归一化:

package com.example.utils;

import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

public class ImagePreprocessor {
    // 静态加载 OpenCV 库
    static {
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    // ResNet-50 输入尺寸
    private static final int INPUT_WIDTH = 224;
    private static final int INPUT_HEIGHT = 224;
    // 归一化均值和标准差(RGB 顺序)
    private static final double[] MEAN = {0.485, 0.456, 0.406};
    private static final double[] STD = {0.229, 0.224, 0.225};

    /**
     * 图片预处理:读取→缩放→BGR转RGB→归一化→转张量格式
     * @param imagePath 图片路径
     * @return 预处理后的浮点型数组 (1, 3, 224, 224)
     */
    public static float[][] preprocess(String imagePath) {
        // 1. 读取图片(OpenCV 默认 BGR 通道)
        Mat image = Imgcodecs.imread(imagePath);
        if (image.empty()) {
            throw new RuntimeException("图片读取失败:" + imagePath);
        }

        // 2. 缩放至 224x224
        Mat resizedImage = new Mat();
        Imgproc.resize(image, resizedImage, new Size(INPUT_WIDTH, INPUT_HEIGHT));

        // 3. 转换为浮点型,归一化到 [0,1]
        resizedImage.convertTo(resizedImage, Core.CV_32F, 1.0 / 255.0);

        // 4. 分离通道(BGR → RGB)并归一化
        Mat[] channels = new Mat[3];
        Core.split(resizedImage, channels);
        for (int i = 0; i < 3; i++) {
            // OpenCV 是 BGR,转换为 RGB:i=0→B→对应索引2,i=1→G→索引1,i=2→R→索引0
            int rgbIndex = 2 - i;
            Core.subtract(channels[i], MEAN[rgbIndex], channels[i]);
            Core.divide(channels[i], STD[rgbIndex], channels[i]);
        }

        // 5. 重组为 (3, 224, 224) 的张量
        Mat rgbImage = new Mat();
        Core.merge(channels, rgbImage);

        // 6. 转换为 float 数组,格式为 [1, 3, 224, 224]
        float[][] inputTensor = new float[1][3 * INPUT_WIDTH * INPUT_HEIGHT];
        float[] pixels = new float[(int) (rgbImage.total() * rgbImage.channels())];
        rgbImage.get(0, 0, pixels);
        inputTensor[0] = pixels;

        // 释放资源
        image.release();
        resizedImage.release();
        rgbImage.release();
        for (Mat ch : channels) {
            ch.release();
        }

        return inputTensor;
    }
}

3.3 ONNX 模型加载工具(OnnxModelLoader.java)

封装 ONNX Runtime 模型加载逻辑,创建推理会话(InferenceSession):

package com.example;

import com.microsoft.onnxruntime.OrtEnvironment;
import com.microsoft.onnxruntime.OrtException;
import com.microsoft.onnxruntime.OrtSession;

import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;

public class OnnxModelLoader {
    private static OrtEnvironment env;

    static {
        // 初始化 ONNX Runtime 环境
        try {
            env = OrtEnvironment.getEnvironment();
        } catch (OrtException e) {
            throw new RuntimeException("初始化 ONNX Runtime 环境失败", e);
        }
    }

    /**
     * 加载 ONNX 模型(从 resources 读取)
     * @param modelPath 模型路径(resources 下)
     * @return OrtSession 推理会话
     */
    public static OrtSession loadModel(String modelPath) {
        try {
            // 将 resources 中的模型复制到临时文件(ONNX Runtime 需文件路径)
            InputStream is = OnnxModelLoader.class.getClassLoader().getResourceAsStream(modelPath);
            Path tempModel = Files.createTempFile("onnx_model_", ".onnx");
            tempModel.toFile().deleteOnExit(); // 程序退出时删除临时文件
            Files.copy(is, tempModel, StandardCopyOption.REPLACE_EXISTING);

            // 创建推理会话
            OrtSession.SessionOptions options = new OrtSession.SessionOptions();
            // 可选:设置优化级别(0-9)
            options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
            // 可选:设置执行模式(SEQUENTIAL/PARALLEL)
            options.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL);

            return env.createSession(tempModel.toString(), options);
        } catch (Exception e) {
            throw new RuntimeException("加载 ONNX 模型失败", e);
        }
    }
}

3.4 核心分类逻辑(ImageClassifier.java)

整合模型加载、图片预处理、推理执行、结果解析全流程:

package com.example;

import com.alibaba.fastjson2.JSONObject;
import com.microsoft.onnxruntime.OrtException;
import com.microsoft.onnxruntime.OrtSession;
import com.microsoft.onnxruntime.Tensor;
import com.example.utils.ImagePreprocessor;
import com.example.utils.LabelLoader;

import java.nio.FloatBuffer;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class ImageClassifier {
    // 模型路径和标签路径
    private static final String MODEL_PATH = "models/resnet50-v1-12.onnx";
    private static final String LABEL_PATH = "labels/imagenet_labels.txt";

    private final OrtSession session;
    private final List<String> labels;

    // 单例模式
    private static volatile ImageClassifier instance;

    private ImageClassifier() {
        this.session = OnnxModelLoader.loadModel(MODEL_PATH);
        this.labels = LabelLoader.loadLabels(LABEL_PATH);
    }

    public static ImageClassifier getInstance() {
        if (instance == null) {
            synchronized (ImageClassifier.class) {
                if (instance == null) {
                    instance = new ImageClassifier();
                }
            }
        }
        return instance;
    }

    /**
     * 图片分类推理
     * @param imagePath 图片路径
     * @return 分类结果(JSON 格式:标签、置信度)
     * @throws OrtException ONNX Runtime 异常
     */
    public JSONObject classify(String imagePath) throws OrtException {
        // 1. 图片预处理
        float[][] inputData = ImagePreprocessor.preprocess(imagePath);
        long[] inputShape = {1, 3, 224, 224}; // ResNet-50 输入形状

        // 2. 创建输入张量
        Tensor<Float> inputTensor = Tensor.createTensor(env, FloatBuffer.wrap(flatten(inputData)), inputShape);

        // 3. 执行推理
        OrtSession.Result result = session.run(Collections.singletonMap("data_0", inputTensor));

        // 4. 解析输出结果
        float[] outputData = ((float[][]) result.get(0).getValue())[0];
        int maxIndex = getMaxIndex(outputData);
        String label = labels.get(maxIndex);
        float confidence = outputData[maxIndex];

        // 5. 封装结果
        JSONObject resultJson = new JSONObject();
        resultJson.put("imagePath", imagePath);
        resultJson.put("label", label);
        resultJson.put("confidence", String.format("%.4f", confidence));
        resultJson.put("labelIndex", maxIndex);

        // 释放资源
        inputTensor.close();
        result.close();

        return resultJson;
    }

    // 扁平化二维数组
    private float[] flatten(float[][] data) {
        float[] flat = new float[data.length * data[0].length];
        int idx = 0;
        for (float[] row : data) {
            System.arraycopy(row, 0, flat, idx, row.length);
            idx += row.length;
        }
        return flat;
    }

    // 获取置信度最大的索引
    private int getMaxIndex(float[] array) {
        int maxIndex = 0;
        float maxValue = array[0];
        for (int i = 1; i < array.length; i++) {
            if (array[i] > maxValue) {
                maxValue = array[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    // 关闭会话
    public void close() {
        try {
            session.close();
            env.close();
        } catch (OrtException e) {
            e.printStackTrace();
        }
    }

    // 测试入口
    public static void main(String[] args) {
        if (args.length == 0) {
            System.out.println("请传入图片路径,例如:java -jar xxx.jar test.jpg");
            return;
        }
        ImageClassifier classifier = ImageClassifier.getInstance();
        try {
            JSONObject result = classifier.classify(args[0]);
            System.out.println("分类结果:" + result.toJSONString());
        } catch (OrtException e) {
            e.printStackTrace();
        } finally {
            classifier.close();
        }
    }
}

四、测试与验证

4.1 编写测试类(ImageClassifierTest.java)

package com.example;

import com.alibaba.fastjson2.JSONObject;
import com.microsoft.onnxruntime.OrtException;
import org.junit.Test;

public class ImageClassifierTest {
    @Test
    public void testClassify() {
        // 替换为你的测试图片路径(如 cat.jpg、dog.jpg)
        String imagePath = "src/test/resources/test.jpg";
        ImageClassifier classifier = ImageClassifier.getInstance();
        try {
            JSONObject result = classifier.classify(imagePath);
            System.out.println("测试结果:" + result);
        } catch (OrtException e) {
            e.printStackTrace();
        } finally {
            classifier.close();
        }
    }
}

4.2 运行测试

  1. 将测试图片(如 test.jpg,包含一只猫)放入 src/test/resources/ 目录;
  2. 运行 ImageClassifierTest,输出示例:
测试结果:{"imagePath":"src/test/resources/test.jpg","label":"tiger cat","confidence":"0.9876","labelIndex":282}

4.3 打包运行

执行 mvn clean package 打包,生成可执行 JAR 包,运行命令:

java -jar onnxruntime-java-image-classification-1.0-SNAPSHOT.jar /path/to/your/image.jpg

五、关键注意事项

5.1 模型输入输出名称

ResNet-50 ONNX 模型的输入名称为 data_0,输出名称为 softmax_1,需根据实际模型调整(可通过 Netron 查看模型结构)。

5.2 硬件加速

若需 GPU 加速:

  • 替换依赖为 onnxruntime-gpu
  • 在 OrtSession.SessionOptions 中设置 GPU 设备:
options.addCUDA(0); // 使用第0块GPU
  • 确保本地安装对应版本的 CUDA 和 cuDNN。

5.3 内存管理

ONNX Runtime 的 TensorOrtSession.ResultOrtSession 等资源需手动关闭,避免内存泄漏。

5.4 预处理精度

图片预处理的均值、标准差、尺寸需与模型训练时一致,否则会导致推理结果错误。

六、总结

本文以图片分类为例,完整实现了 ONNX Runtime for Java 的端到端推理流程,涵盖模型加载、图片预处理、推理执行、结果解析核心环节。通过该案例,你可以快速掌握:

  1. ONNX Runtime Java 版的基本使用方式;
  2. 深度学习模型推理的通用流程(预处理→推理→后处理);
  3. Java 结合 OpenCV 进行图片预处理的方法。

该方案可扩展至其他 ONNX 模型(如目标检测、语义分割),仅需调整预处理逻辑和结果解析规则。ONNX Runtime 为 Java 提供了高性能的 AI 推理能力,可广泛应用于企业级 AI 应用、边缘计算等场景。

参考资料

  1. ONNX Runtime Java 官方文档
  2. ONNX Model Zoo
  3. OpenCV Java 文档
  4. ResNet-50 模型说明
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

canjun_wen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值