一、导出一个onnx模型
这里训练了一个简单的线性回归模型
通过SerializeToString完成导出。
from sklearn.linear_model import LinearRegression
import numpy as np
import onnx
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
# 训练一个简单的线性回归模型
X = np.array([[1], [2], [3], [4]])
y = np.array([2, 4, 6, 8])
model = LinearRegression()
model.fit(X, y)
initial_type = [('float_input', FloatTensorType([None, 1]))]
onnx_model = convert_sklearn(model, initial_types = initial_type)
with open('linear_regression.onnx', 'wb') as f:
f.write(onnx_model.SerializeToString())
二、java项目中maven添加依赖
<dependency>
<groupId>ai.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.12.0</version>
</dependency>
三、代码调用
import ai.onnxruntime.*;
import java.nio.FloatBuffer;
import java.util.Map;
public class ONNXModelCaller {
public static void main(String[] args) throws OrtException {
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession session = env.createSession("linear_regression.onnx", new OrtSession.SessionOptions());
float[] inputData = {5.0f};
Tensor inputTensor = Tensor.createFromByteBuffer(FloatBuffer.wrap(inputData), new long[]{1, 1});
OrtSession.Result results = session.run(Map.of("float_input", inputTensor));
Tensor outputTensor = results.get(0);
float[] outputData = (float[]) outputTensor.getData();
System.out.println("Prediction: " + outputData[0]);
results.close();
session.close();
env.close();
}
}
基于深度学习框架(如 PyTorch 或 TensorFlow)构建,且可以转换为 ONNX 格式,这种方法可以提供高效的跨语言部署