在springboot框架上构建onnx图像识别接口

一、思路

1.获取模型文件,获取他所需input参数,本文使用的yoloV8提供的demo模型,为(1,3,640,640),80个标签

2.从接口获取图片,转成模型所需input参数

3.把图片数据放入模型中运算,得到结果集

4.对结果集进行过滤和非极大值抑制

5.根据结果集对原图画框打标签,返回图片

二、获取模型,构建模型环境

需要添加依赖


        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>1.16.0</version>
        </dependency>

 具体代码:

public class OnnxExecutor {
    static{
        // 加载opencv需要的动态库
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }
    /**
     * 环境
     */
    public OrtEnvironment env;
    /**
     * session
     */
    public OrtSession session;
    /**
     * 模型参数名称
     */
    public JSONObject names;
    /**
     *
     */
    public long count;
    /**
     *
     */
    public long channels;
    /**
     * 模型高(640)
      */
    public long netHeight;
    /**
     * 模型宽(640)
     */
    public long netWidth;
    /**
     * 图片高
     */
    public float srcw;
    /**
     * 图片宽
     */
    public float srch;
    /**
     * 图像原数据
     */
    public Mat src;
    /**
     * 标注颜色
     */
    public static Scalar color = new Scalar(0, 0, 255);
    public static int tickness = 2;

    /**
     * 检测框筛选阈值,参考 detect.py 中的设置
     */
    public float confThreshold = 0.75f;
    /**
     * 交集阈值
     */
    public float nmsThreshold = 0.5f;

    public OnnxExecutor(String modelPath, Float nmsThreshold) {
        this(modelPath);
        if (Objects.nonNull(nmsThreshold)) {
            this.nmsThreshold = nmsThreshold;
        }
    }

    // modelPath我使用的是绝对地址,System.getProperty("user.dir") + "/src/main/resources/static/best.onnx";
    public OnnxExecutor(String modelPath) {
        try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
            this.env = env;
            // 初始化session
            session = env.createSession(modelPath, new OrtSession.SessionOptions());
            OnnxModelMetadata metadata = session.getMetadata();
            Map<String, NodeInfo> infoMap = session.getInputInfo();
            TensorInfo nodeInfo = (TensorInfo)infoMap.get("images").getInfo();
            String nameClass = metadata.getCustomMetadata().get("names");
            // 标签数据
            names = JSONObject.parseObject(nameClass.replace("\"","\"\""));
            //1 模型每次处理一张图片
            count = nodeInfo.getShape()[0];
            //3 模型通道数
            channels = nodeInfo.getShape()[1];
            //640 模型高
            netHeight = nodeInfo.getShape()[2];
            //640 模型宽
            netWidth = nodeInfo.getShape()[3];
            System.out.println(names.get(0));
        } catch (Exception e) {
            log.error(e.getMessage());
            throw new RuntimeException();
        } finally {
        }
    }
}

三、把图片转换成对应参数格式

图片处理需要依赖OpenCv,我使用的4.8.0版本,从官网下载的jar包和dll。依稀记得4.5.0可以直接使用maven下载,没有试过

        <dependency>
            <groupId>org.opencv</groupId>
            <artifactId>opencv</artifactId>
            <version>4.8.0</version><scope>system</scope>
            <systemPath>${project.basedir}/src/main/resources/lib/opencv/opencv-480.jar</systemPath>
        </dependency>

加载动态库

    static{
        // 加载opencv需要的动态库
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

这个里有个坑,在使用动态库的时候,不能添加spring-boot-devtools依赖,会导致在图像转换的时候报找不到本地函数n_convertTo,具体原因不是很清楚。

图片转换及使用模型获取结果集

public String identifyImg(MultipartFile input){
        //Image图片
        try {
            // 获取图片类型(PNG/JPRG)
            Path path = Paths.get(input.getOriginalFilename());
            // jpeg jpg bmp gif wbmp
            String imgType = Files.probeContentType(path);
            String[] split = imgType.split("\\/");
            if (split.length > 1) {
                imgType = split[1];
            }
            // 1.图片格式转换
            BufferedImage image = ImageIO.read(input.getInputStream());
//            new File();
            srcw = image.getWidth();
            srch = image.getHeight();
            // 将BufferedImage转换为字节数组
            ByteArrayOutputStream baos = new ByteArrayOutputStream();
            ImageIO.write(image, imgType, baos);
            // mat:矩阵
            src = new Mat(image.getHeight(), image.getWidth(), CvType.CV_8UC3);
            src.put(0, 0, ((DataBufferByte) image.getRaster().getDataBuffer()).getData());
            // 需处理图像数据
            Mat dst = new Mat();
            Imgproc.resize(src, dst, new Size(netWidth, netHeight));
            // 2.图片转换为模型所需参数
            Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
            // 归一化 0-255 转 0-1
            dst.convertTo(dst, CvType.CV_32FC1, 1. / 255);
            float[] whc = new float[ Long.valueOf(channels).intValue() * Long.valueOf(netWidth).intValue() * Long.valueOf(netHeight).intValue() ];
            dst.get(0, 0, whc);
            float[] chw = new float[src.length];
            int j = 0;
            for (int ch = 0; ch < 3; ++ch) {
                for (int i = ch; i < src.length; i += 3) {
                    chw[j] = src[i];
                    j++;
                }
            }
            OnnxTensor tensor = null;
            try {
                tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{count,channels,netWidth,netHeight});
            }
            catch (Exception e){
                e.printStackTrace();
                System.exit(0);
            }
            // 执行
            OrtSession.Result result = session.run(Collections.singletonMap("images", tensor));
            System.out.println("res Data: "+result.get(0));
            // 结果处理
            OnnxTensor res = (OnnxTensor)result.get(0);
            // [图片数量][84 = x+y+w+h+80个数据集类别][8400]
            float[][][] dataRes = (float[][][])res.getValue();
            // 获取第一张图片的结果
            float[][] data = dataRes[0];
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {

        }
        return "ok";
    }

四、对结果集进行处理

首先对于模型返回的数据需要有个大体的认识,模型的返回值是由模型本身决定的,例子中他的输出结构为[图片数量][坐标(x,y,w,h)加标签匹配度][8400(固定值,模型各尺度输出特征图叠加之后的结果)]

private Map<Object,Object> method2(float[][] data) {
        // 将矩阵转置
        // 先将xywh部分转置
        float rawData[][]=new float[data[0].length][6];
        System.out.println(data.length-1);
        for(int i=0;i<4;i++){
            for(int j=0;j<data[0].length;j++){
                rawData[j][i]=data[i][j];
            }
        }
        // 保存每个检查框置信值最高的类型置信值和该类型下标
        for(int i=0;i<data[0].length;i++){
            for(int j=4;j<data.length;j++){
                if(rawData[i][4]<data[j][i]){
                    //置信值
                    rawData[i][4]=data[j][i];
                    //类型编号
                    rawData[i][5]=j-4;
                }
            }
        }
        List<ArrayList<Float>> boxes=new LinkedList<ArrayList<Float>>();
        ArrayList<Float> box=null;
        // 置信值过滤,xywh转xyxy
        for(float[] d:rawData){
            // 置信值过滤
            if(d[4]>confThreshold){
                // 给的坐标系是中间点的
                // xywh(xy为中心点)转xyxy
                d[0] = d[0] - d[2]/2;
                d[1] = d[1] - d[3]/2;
                d[2] = d[0] + d[2];
                d[3] = d[1] + d[3];

                // 置信值符合的进行插入法排序保存
                box=new ArrayList<Float>();
                for(float num:d) {
                    box.add(num);
                }
                if(boxes.size()==0){
                    boxes.add(box);
                }else {
                    int i;
                    for(i=0;i<boxes.size();i++){
                        if(box.get(4)>boxes.get(i).get(4)){
                            boxes.add(i,box);
                            break;
                        }
                    }
                    // 插入到最后
                    if(i==boxes.size()){
                        boxes.add(box);
                    }
                }
            }
        }

        // 每个框分别有x1、x1、x2、y2、conf、class
        //System.out.println(boxes);

        // 非极大值抑制
        int[] indexs=new int[boxes.size()];
        Arrays.fill(indexs,1);
        //用于标记1保留,0删除
        for(int cur=0;cur<boxes.size();cur++){
            if(indexs[cur]==0){
                continue;
            }
            //当前框代表该类置信值最大的框
            ArrayList<Float> curMaxConf=boxes.get(cur);
            for(int i=cur+1;i<boxes.size();i++){
                if(indexs[i]==0){
                    continue;
                }
                float classIndex=boxes.get(i).get(5);
                // 两个检测框都检测到同一类数据,通过iou来判断是否检测到同一目标,这就是非极大值抑制
                if(classIndex==curMaxConf.get(5)){
                    float x1=curMaxConf.get(0);
                    float y1=curMaxConf.get(1);
                    float x2=curMaxConf.get(2);
                    float y2=curMaxConf.get(3);
                    float x3=boxes.get(i).get(0);
                    float y3=boxes.get(i).get(1);
                    float x4=boxes.get(i).get(2);
                    float y4=boxes.get(i).get(3);
                    //将几种不相交的情况排除。提示:x1y1、x2y2、x3y3、x4y4对应两框的左上角和右下角
                    if(x1>x4||x2<x3||y1>y4||y2<y3){
                        continue;
                    }
                    // 两个矩形的交集面积
                    float intersectionWidth =Math.max(x1, x3) - Math.min(x2, x4);
                    float intersectionHeight=Math.max(y1, y3) - Math.min(y2, y4);
                    float intersectionArea =Math.max(0,intersectionWidth * intersectionHeight);
                    // 两个矩形的并集面积
                    float unionArea = (x2-x1)*(y2-y1)+(x4-x3)*(y4-y3)-intersectionArea;
                    // 计算IoU
                    float iou = intersectionArea / unionArea;
                    // 对交并比超过阈值的标记
                    indexs[i]=iou>nmsThreshold?0:1;
                    //System.out.println(cur+" "+i+" class"+curMaxConf.get(5)+" "+classIndex+"  u:"+unionArea+" i:"+intersectionArea+"  iou:"+ iou);
                }
            }
        }


        List<ArrayList<Float>> resBoxes=new LinkedList<ArrayList<Float>>();
        for(int index=0;index<indexs.length;index++){
            if(indexs[index]==1) {
                resBoxes.add(boxes.get(index));
            }
        }
        boxes=resBoxes;

        float scaleW=srcw/netWidth;
        float scaleH=srch/netHeight;
        System.out.println("boxes.size : "+boxes.size());
        for(ArrayList<Float> box1:boxes){
            box1.set(0,box1.get(0)*scaleW);
            box1.set(1,box1.get(1)*scaleH);
            box1.set(2,box1.get(2)*scaleW);
            box1.set(3,box1.get(3)*scaleH);
        }
        System.out.println("boxes: "+boxes);
        //detect(boxes);
        Map<Object,Object> map=new HashMap<Object,Object>();
        map.put("boxes",boxes);
        map.put("classNames",names);
        return map;
    }

处理之后的map就是图片上需要标记的框及识别的标签。

五、画框,输出结果

后面就是根据map数据,在原图上画框

    public Mat showDetect(Map<Object,Object> map){
        List<ArrayList<Float>> boxes=(List<ArrayList<Float>>)map.get("boxes");
        JSONObject names=(JSONObject) map.get("classNames");
        // 画框,加数据
        for(ArrayList<Float> box:boxes){
            float x1=box.get(0);
            float y1=box.get(1);
            float x2=box.get(2);
            float y2=box.get(3);
            float config=box.get(4);
            String className=(String)names.get((int)box.get(5).intValue());;
            Point point1=new Point(x1,y1);
            Point point2=new Point(x2,y2);
            Imgproc.rectangle(src, point1, point2 , color,2);
            String conf=new DecimalFormat("#.###").format(config);
            Imgproc.putText(src,className+" "+conf,new Point(x1,y1-5),0,0.5,new Scalar(255,0,0),1);
        }
        return src;
    }

我在这里是直接把mat转成流保存到本地了,各位就拿着mat放到response的outputstream里扔出去。

效果参考

狗头保命:图片来自于百度图片

结尾

这里只是做一个demo学习用,其中看不懂的我也不懂,各位自行百度。 

参考文章:ONNX格式模型 学习笔记 (onnxRuntime部署)---用java调用yolov8模型来举例_yolov8 java-优快云博客

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值