java调用ONNX模型

一、导出一个onnx模型

这里训练了一个简单的线性回归模型
通过SerializeToString完成导出。

from sklearn.linear_model import LinearRegression
import numpy as np
import onnx
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

# 训练一个简单的线性回归模型
X = np.array([[1], [2], [3], [4]])
y = np.array([2, 4, 6, 8])
model = LinearRegression()
model.fit(X, y)

initial_type = [('float_input', FloatTensorType([None, 1]))]
onnx_model = convert_sklearn(model, initial_types = initial_type)
with open('linear_regression.onnx', 'wb') as f:
    f.write(onnx_model.SerializeToString())

二、java项目中maven添加依赖

<dependency>
    <groupId>ai.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.12.0</version>
</dependency>

三、代码调用

import ai.onnxruntime.*;
import java.nio.FloatBuffer;
import java.util.Map;

public class ONNXModelCaller {
    public static void main(String[] args) throws OrtException {
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession session = env.createSession("linear_regression.onnx", new OrtSession.SessionOptions());

        float[] inputData = {5.0f};
        Tensor inputTensor = Tensor.createFromByteBuffer(FloatBuffer.wrap(inputData), new long[]{1, 1});
        OrtSession.Result results = session.run(Map.of("float_input", inputTensor));
        Tensor outputTensor = results.get(0);
        float[] outputData = (float[]) outputTensor.getData();
        System.out.println("Prediction: " + outputData[0]);

        results.close();
        session.close();
        env.close();
    }
}

基于深度学习框架(如 PyTorch 或 TensorFlow)构建,且可以转换为 ONNX 格式,这种方法可以提供高效的跨语言部署

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值