ONNX Runtime 是一款跨平台、高性能的推理引擎,能够高效运行 ONNX 格式的深度学习模型。Java 作为企业级开发的主流语言,结合 ONNX Runtime 可轻松实现 AI 模型的端到端推理。本文将以图片分类为核心场景,从零搭建一个基于 ONNX Runtime for Java 的完整项目,详细讲解环境配置、模型加载、数据预处理、推理执行及结果解析全流程。
一、实战背景与准备工作
1.1 核心目标
使用预训练的 ResNet-50 ONNX 模型,通过 Java + ONNX Runtime 实现对任意图片的分类推理,输出图片的类别标签及置信度。
1.2 环境依赖
- JDK 8 及以上(推荐 11/17,兼容 ONNX Runtime 最新版本)
- Maven 3.6+(项目构建工具)
- ONNX Runtime Java 包(onnxruntime-gpu/onnxruntime,根据硬件选择)
- 预训练 ResNet-50 ONNX 模型(可从 ONNX Model Zoo 下载)
- 辅助依赖:OpenCV(图片预处理)、FastJSON(结果格式化)
1.3 模型与标签准备
- 下载 ResNet-50 ONNX 模型:从 ONNX Model Zoo 下载
resnet50-v1-12.onnx,保存至项目src/main/resources/models目录。 - 下载 ImageNet 标签文件:创建
src/main/resources/labels/imagenet_labels.txt,保存 1000 类 ImageNet 标签(可从 GitHub 转换为文本格式)。
二、项目搭建与依赖配置
2.1 创建 Maven 项目
通过 IDEA/Eclipse 创建 Maven 项目,核心 pom.xml 依赖配置如下:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>onnxruntime-java-image-classification</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<onnxruntime.version>1.17.0</onnxruntime.version>
<opencv.version>4.8.0</opencv.version>
<fastjson.version>2.0.43</fastjson.version>
</properties>
<dependencies>
<!-- ONNX Runtime Java 核心依赖(CPU版),GPU版替换为 onnxruntime-gpu -->
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>${onnxruntime.version}</version>
</dependency>
<!-- OpenCV 用于图片预处理(读取、缩放、归一化) -->
<dependency>
<groupId>org.openpnp</groupId>
<artifactId>opencv</artifactId>
<version>${opencv.version}</version>
</dependency>
<!-- FastJSON 用于结果格式化输出 -->
<dependency>
<groupId>com.alibaba.fastjson2</groupId>
<artifactId>fastjson2</artifactId>
<version>${fastjson.version}</version>
</dependency>
<!-- 测试依赖 -->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.2</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<!-- 打包时包含依赖 -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.4.1</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
2.2 目录结构
最终项目目录如下:
onnxruntime-java-image-classification/
├── src/
│ ├── main/
│ │ ├── java/
│ │ │ └── com/
│ │ │ └── example/
│ │ │ ├── ImageClassifier.java // 核心分类逻辑
│ │ │ ├── OnnxModelLoader.java // 模型加载工具
│ │ │ └── utils/
│ │ │ ├── ImagePreprocessor.java // 图片预处理
│ │ │ └── LabelLoader.java // 标签加载工具
│ │ └── resources/
│ │ ├── models/
│ │ │ └── resnet50-v1-12.onnx // ResNet-50 模型
│ │ └── labels/
│ │ └── imagenet_labels.txt // ImageNet 标签
│ └── test/
│ └── java/
│ └── com/
│ └── example/
│ └── ImageClassifierTest.java // 测试类
└── pom.xml
三、核心功能实现
3.1 标签加载工具(LabelLoader.java)
读取 imagenet_labels.txt,将标签存储为 List,方便后续根据索引查询类别:
package com.example.utils;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
public class LabelLoader {
/**
* 加载标签文件
* @param labelPath 标签文件路径(resources 下)
* @return 标签列表
*/
public static List<String> loadLabels(String labelPath) {
List<String> labels = new ArrayList<>();
try (InputStream is = LabelLoader.class.getClassLoader().getResourceAsStream(labelPath);
BufferedReader br = new BufferedReader(new InputStreamReader(is))) {
String line;
while ((line = br.readLine()) != null) {
labels.add(line.trim());
}
} catch (Exception e) {
throw new RuntimeException("加载标签文件失败", e);
}
return labels;
}
}
3.2 图片预处理工具(ImagePreprocessor.java)
ResNet-50 要求输入为 (1, 3, 224, 224) 的张量,且需进行归一化(均值 [0.485, 0.456, 0.406],标准差 [0.229, 0.224, 0.225])。使用 OpenCV 实现图片读取、缩放、通道转换、归一化:
package com.example.utils;
import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
public class ImagePreprocessor {
// 静态加载 OpenCV 库
static {
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
}
// ResNet-50 输入尺寸
private static final int INPUT_WIDTH = 224;
private static final int INPUT_HEIGHT = 224;
// 归一化均值和标准差(RGB 顺序)
private static final double[] MEAN = {0.485, 0.456, 0.406};
private static final double[] STD = {0.229, 0.224, 0.225};
/**
* 图片预处理:读取→缩放→BGR转RGB→归一化→转张量格式
* @param imagePath 图片路径
* @return 预处理后的浮点型数组 (1, 3, 224, 224)
*/
public static float[][] preprocess(String imagePath) {
// 1. 读取图片(OpenCV 默认 BGR 通道)
Mat image = Imgcodecs.imread(imagePath);
if (image.empty()) {
throw new RuntimeException("图片读取失败:" + imagePath);
}
// 2. 缩放至 224x224
Mat resizedImage = new Mat();
Imgproc.resize(image, resizedImage, new Size(INPUT_WIDTH, INPUT_HEIGHT));
// 3. 转换为浮点型,归一化到 [0,1]
resizedImage.convertTo(resizedImage, Core.CV_32F, 1.0 / 255.0);
// 4. 分离通道(BGR → RGB)并归一化
Mat[] channels = new Mat[3];
Core.split(resizedImage, channels);
for (int i = 0; i < 3; i++) {
// OpenCV 是 BGR,转换为 RGB:i=0→B→对应索引2,i=1→G→索引1,i=2→R→索引0
int rgbIndex = 2 - i;
Core.subtract(channels[i], MEAN[rgbIndex], channels[i]);
Core.divide(channels[i], STD[rgbIndex], channels[i]);
}
// 5. 重组为 (3, 224, 224) 的张量
Mat rgbImage = new Mat();
Core.merge(channels, rgbImage);
// 6. 转换为 float 数组,格式为 [1, 3, 224, 224]
float[][] inputTensor = new float[1][3 * INPUT_WIDTH * INPUT_HEIGHT];
float[] pixels = new float[(int) (rgbImage.total() * rgbImage.channels())];
rgbImage.get(0, 0, pixels);
inputTensor[0] = pixels;
// 释放资源
image.release();
resizedImage.release();
rgbImage.release();
for (Mat ch : channels) {
ch.release();
}
return inputTensor;
}
}
3.3 ONNX 模型加载工具(OnnxModelLoader.java)
封装 ONNX Runtime 模型加载逻辑,创建推理会话(InferenceSession):
package com.example;
import com.microsoft.onnxruntime.OrtEnvironment;
import com.microsoft.onnxruntime.OrtException;
import com.microsoft.onnxruntime.OrtSession;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
public class OnnxModelLoader {
private static OrtEnvironment env;
static {
// 初始化 ONNX Runtime 环境
try {
env = OrtEnvironment.getEnvironment();
} catch (OrtException e) {
throw new RuntimeException("初始化 ONNX Runtime 环境失败", e);
}
}
/**
* 加载 ONNX 模型(从 resources 读取)
* @param modelPath 模型路径(resources 下)
* @return OrtSession 推理会话
*/
public static OrtSession loadModel(String modelPath) {
try {
// 将 resources 中的模型复制到临时文件(ONNX Runtime 需文件路径)
InputStream is = OnnxModelLoader.class.getClassLoader().getResourceAsStream(modelPath);
Path tempModel = Files.createTempFile("onnx_model_", ".onnx");
tempModel.toFile().deleteOnExit(); // 程序退出时删除临时文件
Files.copy(is, tempModel, StandardCopyOption.REPLACE_EXISTING);
// 创建推理会话
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
// 可选:设置优化级别(0-9)
options.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
// 可选:设置执行模式(SEQUENTIAL/PARALLEL)
options.setExecutionMode(OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL);
return env.createSession(tempModel.toString(), options);
} catch (Exception e) {
throw new RuntimeException("加载 ONNX 模型失败", e);
}
}
}
3.4 核心分类逻辑(ImageClassifier.java)
整合模型加载、图片预处理、推理执行、结果解析全流程:
package com.example;
import com.alibaba.fastjson2.JSONObject;
import com.microsoft.onnxruntime.OrtException;
import com.microsoft.onnxruntime.OrtSession;
import com.microsoft.onnxruntime.Tensor;
import com.example.utils.ImagePreprocessor;
import com.example.utils.LabelLoader;
import java.nio.FloatBuffer;
import java.util.Collections;
import java.util.List;
import java.util.Map;
public class ImageClassifier {
// 模型路径和标签路径
private static final String MODEL_PATH = "models/resnet50-v1-12.onnx";
private static final String LABEL_PATH = "labels/imagenet_labels.txt";
private final OrtSession session;
private final List<String> labels;
// 单例模式
private static volatile ImageClassifier instance;
private ImageClassifier() {
this.session = OnnxModelLoader.loadModel(MODEL_PATH);
this.labels = LabelLoader.loadLabels(LABEL_PATH);
}
public static ImageClassifier getInstance() {
if (instance == null) {
synchronized (ImageClassifier.class) {
if (instance == null) {
instance = new ImageClassifier();
}
}
}
return instance;
}
/**
* 图片分类推理
* @param imagePath 图片路径
* @return 分类结果(JSON 格式:标签、置信度)
* @throws OrtException ONNX Runtime 异常
*/
public JSONObject classify(String imagePath) throws OrtException {
// 1. 图片预处理
float[][] inputData = ImagePreprocessor.preprocess(imagePath);
long[] inputShape = {1, 3, 224, 224}; // ResNet-50 输入形状
// 2. 创建输入张量
Tensor<Float> inputTensor = Tensor.createTensor(env, FloatBuffer.wrap(flatten(inputData)), inputShape);
// 3. 执行推理
OrtSession.Result result = session.run(Collections.singletonMap("data_0", inputTensor));
// 4. 解析输出结果
float[] outputData = ((float[][]) result.get(0).getValue())[0];
int maxIndex = getMaxIndex(outputData);
String label = labels.get(maxIndex);
float confidence = outputData[maxIndex];
// 5. 封装结果
JSONObject resultJson = new JSONObject();
resultJson.put("imagePath", imagePath);
resultJson.put("label", label);
resultJson.put("confidence", String.format("%.4f", confidence));
resultJson.put("labelIndex", maxIndex);
// 释放资源
inputTensor.close();
result.close();
return resultJson;
}
// 扁平化二维数组
private float[] flatten(float[][] data) {
float[] flat = new float[data.length * data[0].length];
int idx = 0;
for (float[] row : data) {
System.arraycopy(row, 0, flat, idx, row.length);
idx += row.length;
}
return flat;
}
// 获取置信度最大的索引
private int getMaxIndex(float[] array) {
int maxIndex = 0;
float maxValue = array[0];
for (int i = 1; i < array.length; i++) {
if (array[i] > maxValue) {
maxValue = array[i];
maxIndex = i;
}
}
return maxIndex;
}
// 关闭会话
public void close() {
try {
session.close();
env.close();
} catch (OrtException e) {
e.printStackTrace();
}
}
// 测试入口
public static void main(String[] args) {
if (args.length == 0) {
System.out.println("请传入图片路径,例如:java -jar xxx.jar test.jpg");
return;
}
ImageClassifier classifier = ImageClassifier.getInstance();
try {
JSONObject result = classifier.classify(args[0]);
System.out.println("分类结果:" + result.toJSONString());
} catch (OrtException e) {
e.printStackTrace();
} finally {
classifier.close();
}
}
}
四、测试与验证
4.1 编写测试类(ImageClassifierTest.java)
package com.example;
import com.alibaba.fastjson2.JSONObject;
import com.microsoft.onnxruntime.OrtException;
import org.junit.Test;
public class ImageClassifierTest {
@Test
public void testClassify() {
// 替换为你的测试图片路径(如 cat.jpg、dog.jpg)
String imagePath = "src/test/resources/test.jpg";
ImageClassifier classifier = ImageClassifier.getInstance();
try {
JSONObject result = classifier.classify(imagePath);
System.out.println("测试结果:" + result);
} catch (OrtException e) {
e.printStackTrace();
} finally {
classifier.close();
}
}
}
4.2 运行测试
- 将测试图片(如
test.jpg,包含一只猫)放入src/test/resources/目录; - 运行
ImageClassifierTest,输出示例:
测试结果:{"imagePath":"src/test/resources/test.jpg","label":"tiger cat","confidence":"0.9876","labelIndex":282}
4.3 打包运行
执行 mvn clean package 打包,生成可执行 JAR 包,运行命令:
java -jar onnxruntime-java-image-classification-1.0-SNAPSHOT.jar /path/to/your/image.jpg
五、关键注意事项
5.1 模型输入输出名称
ResNet-50 ONNX 模型的输入名称为 data_0,输出名称为 softmax_1,需根据实际模型调整(可通过 Netron 查看模型结构)。
5.2 硬件加速
若需 GPU 加速:
- 替换依赖为
onnxruntime-gpu; - 在
OrtSession.SessionOptions中设置 GPU 设备:
options.addCUDA(0); // 使用第0块GPU
- 确保本地安装对应版本的 CUDA 和 cuDNN。
5.3 内存管理
ONNX Runtime 的 Tensor、OrtSession.Result、OrtSession 等资源需手动关闭,避免内存泄漏。
5.4 预处理精度
图片预处理的均值、标准差、尺寸需与模型训练时一致,否则会导致推理结果错误。
六、总结
本文以图片分类为例,完整实现了 ONNX Runtime for Java 的端到端推理流程,涵盖模型加载、图片预处理、推理执行、结果解析核心环节。通过该案例,你可以快速掌握:
- ONNX Runtime Java 版的基本使用方式;
- 深度学习模型推理的通用流程(预处理→推理→后处理);
- Java 结合 OpenCV 进行图片预处理的方法。
该方案可扩展至其他 ONNX 模型(如目标检测、语义分割),仅需调整预处理逻辑和结果解析规则。ONNX Runtime 为 Java 提供了高性能的 AI 推理能力,可广泛应用于企业级 AI 应用、边缘计算等场景。
1251

被折叠的 条评论
为什么被折叠?



