android 使用 onnxruntime 部署 yolov5_face_landmark 人脸检测

下面是使用 opencv-camera,实时处理区域内人脸检测 android 推理 demo。

首先是整合 opcv-camera 进去:

为了方便直接将整个 opencv-android-sdk 全部导入:

 然后在原来的项目模块app中添加 opencv的 java 相关依赖,主要添加红色两行:
app/build.grandle

dependencies {
    implementation fileTree(dir: 'libs', include: ['*.jar'])
    implementation 'androidx.appcompat:appcompat:1.4.1'
    implementation 'com.google.android.material:material:1.5.0'
    implementation 'androidx.constraintlayout:constraintlayout:2.1.3'
    testImplementation 'junit:junit:4.13.2'
    androidTestImplementation 'androidx.test.ext:junit:1.1.3'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
    implementation project(':opencvsdk')
}

最后在项目中要使用opencv的地方加载jni库,可以添加到 MainActivity 中:

System.loadLibrary("opencv_java4"); 或者 OpenCVLoader.initDebug();

要使用 opencv-camera,MainActivity 继承 CameraActivity,然后在回调函数中获取每一帧进行处理,比如下面对每一帧添加识别区域边框:

   // 获取每一帧回调数据
    private CameraBridgeViewBase.CvCameraViewListener2 cameraViewListener2 = new CameraBridgeViewBase.CvCameraViewListener2() {
        @Override
        public void onCameraViewStarted(int width, int height) {
            System.out.println("开始预览 width="+width+",height="+height);
            // 预览界面是 640*480,模型输入时 320*320,计算识别区域坐标
            int detection_x1 = (640 - OnnxUtil.w)/2;
            int detection_x2 = (640 - OnnxUtil.w)/2 + OnnxUtil.w;
            int detection_y1 = (480 - OnnxUtil.h)/2;
            int detection_y2 = (480 - OnnxUtil.h)/2 + OnnxUtil.h;;
            System.out.println("识别区域:"+"("+detection_x1+","+detection_y1+")"+"("+detection_x2+","+detection_y2+")");
            // 缓存识别区域两个点
            detection_p1 = new Point(detection_x1,detection_y1);
            detection_p2 = new Point(detection_x2,detection_y2);
            detection_box_color = new Scalar(255, 0, 0);
            detection_box_tickness = 2;
        }
        @Override
        public void onCameraViewStopped() {}
        @Override
        public Mat onCameraFrame(CameraBridgeViewBase.CvCameraViewFrame frame) {

            // 获取 cv::Mat
            Mat mat = frame.rgba();

            // 标注识别区域
            Imgproc.rectangle(mat, detection_p1, detection_p2,detection_box_color,detection_box_tickness);

            return mat;
        }
    };

在界面中开启预览:

ui资源:
  <org.opencv.android.JavaCamera2View
            android:id="@+id/camera_view"
            app:layout_constraintTop_toTopOf="parent"
            app:layout_constraintLeft_toLeftOf="parent"
            android:layout_width="match_parent"
            android:layout_height="match_parent">

    </org.opencv.android.JavaCamera2View>
java开启预览:
    private BaseLoaderCallback baseLoaderCallback = new BaseLoaderCallback(this) {
        @Override
        public void onManagerConnected(int status) {
            switch (status) {
                case LoaderCallbackInterface.SUCCESS: {
                    if (camera2View != null) {
                        // 设置前置还是后置摄像头 0后置 1前置
                        camera2View.setCameraIndex(cameraId);
                        // 注册每一帧回调
                        camera2View.setCvCameraViewListener(cameraViewListener2);
                        // 显示/关闭 帧率  disableFpsMeter/enableFpsMeter
                        // 要修改字体和颜色直接修改 FpsMeter 类即可
                        camera2View.enableFpsMeter();
                        // 设置视图宽高和模型一致减少resize操作,模型输入一般尺寸不大,这样相机渲染fps会更高
                        camera2View.setMaxFrameSize(win_w,win_h);
                        // 开启
                        camera2View.enableView();
                    }
                }
                break;
                default:
                    super.onManagerConnected(status);
                    break;
            }
        }
    };

下面是全部推理 MainActivity 代码:

package com.example.camera_opencv;


import android.content.pm.ActivityInfo;
import android.os.Bundle;
import android.view.WindowManager;
import com.example.camera_opencv.databinding.ActivityMainBinding;
import org.opencv.android.*;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.imgproc.Imgproc;

import java.util.Arrays;
import java.util.List;

public class MainActivity extends CameraActivity{

    // 动态库
    static {
        // 我们自己的jni
        System.loadLibrary("camera_opencv");
        // 新加的 opencv 的jni
        System.loadLibrary("opencv_java4");
    }

    private ActivityMainBinding binding;

    // 预览界面
    private JavaCamera2View camera2View;

    // 相机编号 0后置 1前置
    private int cameraId = 1;

    // 设置预览界面宽高,在次宽高基础上限制识别区域
    private int win_w = 640;
    private int win_h = 480;

    // 识别区域两个点
    private Point detection_p1;
    private Point detection_p2;
    private Scalar detection_box_color;
    private int detection_box_tickness;

    @Override
    protected void onCreate(Bundle savedInstanceState) {

        super.onCreate(savedInstanceState);
        binding = ActivityMainBinding.inflate(getLayoutInflater());
        setContentView(binding.getRoot());

        // 加载模型
        OnnxUtil.loadModule(getAssets());

        // 强制横屏
        setRequestedOrientation(ActivityInfo.SCREEN_ORIENTATION_LANDSCAPE);
        // 隐藏上方状态栏
        getWindow().setFlags(WindowManager.LayoutParams.FLAG_FULLSCREEN, WindowManager.LayoutParams.FLAG_FULLSCREEN);
        // 预览界面
        camera2View = findViewById(R.id.camera_view);
    }

    @Override
    protected List<? extends CameraBridgeViewBase> getCameraViewList() {
        return Arrays.asList(camera2View);
    }


    @Override
    public void onPause() {
        super.onPause();
        if (camera2View != null) {
            // 关闭预览
            camera2View.disableView();
        }
    }

    @Override
    public void onResume() {
        super.onResume();
        if (OpenCVLoader.initDebug()) {
            baseLoaderCallback.onManagerConnected(LoaderCallbackInterface.SUCCESS);
        } else {
            OpenCVLoader.initAsync(OpenCVLoader.OPENCV_VERSION, this, baseLoaderCallback);
        }
    }

    // 获取每一帧回调数据
    private CameraBridgeViewBase.CvCameraViewListener2 cameraViewListener2 = new CameraBridgeViewBase.CvCameraViewListener2() {
        @Override
        public void onCameraViewStarted(int width, int height) {
            System.out.println("开始预览 width="+width+",height="+height);
            // 预览界面是 640*480,模型输入时 320*320,计算识别区域坐标
            int detection_x1 = (640 - OnnxUtil.w)/2;
            int detection_x2 = (640 - OnnxUtil.w)/2 + OnnxUtil.w;
            int detection_y1 = (480 - OnnxUtil.h)/2;
            int detection_y2 = (480 - OnnxUtil.h)/2 + OnnxUtil.h;;
            System.out.println("识别区域:"+"("+detection_x1+","+detection_y1+")"+"("+detection_x2+","+detection_y2+")");
            // 缓存识别区域两个点
            detection_p1 = new Point(detection_x1,detection_y1);
            detection_p2 = new Point(detection_x2,detection_y2);
            detection_box_color = new Scalar(255, 0, 0);
            detection_box_tickness = 2;
        }
        @Override
        public void onCameraViewStopped() {}
        @Override
        public Mat onCameraFrame(CameraBridgeViewBase.CvCameraViewFrame frame) {

            // 获取 cv::Mat
            Mat mat = frame.rgba();

            // 标注识别区域
            Imgproc.rectangle(mat, detection_p1, detection_p2,detection_box_color,detection_box_tickness);

            // 推理并标注
            OnnxUtil.inference(mat,detection_p1,detection_p2);

            return mat;
        }
    };

    // 开启预览
    private BaseLoaderCallback baseLoaderCallback = new BaseLoaderCallback(this) {
        @Override
        public void onManagerConnected(int status) {
            switch (status) {
                case LoaderCallbackInterface.SUCCESS: {
                    if (camera2View != null) {
                        // 设置前置还是后置摄像头 0后置 1前置
                        camera2View.setCameraIndex(cameraId);
                        // 注册每一帧回调
                        camera2View.setCvCameraViewListener(cameraViewListener2);
                        // 显示/关闭 帧率  disableFpsMeter/enableFpsMeter
                        // 要修改字体和颜色直接修改 FpsMeter 类即可
                        camera2View.enableFpsMeter();
                        // 设置视图宽高和模型一致减少resize操作,模型输入一般尺寸不大,这样相机渲染fps会更高
                        camera2View.setMaxFrameSize(win_w,win_h);
                        // 开启
                        camera2View.enableView();
                    }
                }
                break;
                default:
                    super.onManagerConnected(status);
                    break;
            }
        }
    };

}

onnx 模型加载和推理代码:
使用的微软onnx推理框架:

implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'

package com.example.camera_opencv;

import ai.onnxruntime.*;
import android.content.res.AssetManager;
import org.opencv.core.*;
import org.opencv.dnn.Dnn;
import org.opencv.imgproc.Imgproc;

import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.*;

public class OnnxUtil {

    // onnxruntime 环境
    public static OrtEnvironment env;
    public static OrtSession session;

    // 模型输入
    public static int w = 0;
    public static int h = 0;
    public static int c = 3;

    // 标注颜色
    public static Scalar green = new Scalar(0, 255, 0);
    public static int tickness = 2;

    // 模型加载
    public static void loadModule(AssetManager assetManager){

        // 下面包含了多个模型
        // yolov5face-blazeface-640x640.onnx   3.4Mb
        // yolov5face-l-640x640.onnx   181Mb
        // yolov5face-m-640x640.onnx   	83Mb
        // yolov5face-n-0.5-320x320.onnx   2.5Mb
        // yolov5face-n-0.5-640x640.onnx   4.6Mb
        // yolov5face-n-640x640.onnx   9.5Mb
        // yolov5face-s-640x640.onnx   30Mb

        w = 320;
        h = 320;
        c = 3;

        try {
            // 模型输入:  input -> [1, 3, 320, 320] -> FLOAT
            // 模型输出:  output -> [1, 6300, 16] -> FLOAT
            InputStream inputStream = assetManager.open("yolov5face-n-0.5-320x320.onnx");
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            int nRead;
            byte[] data = new byte[1024];
            while ((nRead = inputStream.read(data, 0, data.length)) != -1) {
                buffer.write(data, 0, nRead);
            }
            buffer.flush();
            byte[] module = buffer.toByteArray();
            System.out.println("开始加载模型");
            env = OrtEnvironment.getEnvironment();
            session = env.createSession(module, new OrtSession.SessionOptions());
            session.getInputInfo().entrySet().stream().forEach(n -> {
                String inputName = n.getKey();
                NodeInfo inputInfo = n.getValue();
                long[] shape = ((TensorInfo) inputInfo.getInfo()).getShape();
                String javaType = ((TensorInfo) inputInfo.getInfo()).type.toString();
                System.out.println("模型输入:  "+inputName + " -> " + Arrays.toString(shape) + " -> " + javaType);
            });
            session.getOutputInfo().entrySet().stream().forEach(n -> {
                String outputName = n.getKey();
                NodeInfo outputInfo = n.getValue();
                long[] shape = ((TensorInfo) outputInfo.getInfo()).getShape();
                String javaType = ((TensorInfo) outputInfo.getInfo()).type.toString();
                System.out.println("模型输出:  "+outputName + " -> " + Arrays.toString(shape) + " -> " + javaType);
            });
        } catch (Exception e) {
            e.printStackTrace();
        }

    }


    // 模型推理,输入原始图片和识别区域两个点
    public static void inference(Mat mat,Point detection_p1,Point detection_p2){

        int px = Double.valueOf(detection_p1.x).intValue();
        int py = Double.valueOf(detection_p1.y).intValue();

        // 提取rgb(chw存储)并做归一化,也就是 rrrrr bbbbb ggggg
        float[] chw = new float[c*h*w];
        // 像素点索引
        int index = 0;
        for(int j=0 ; j<h ; j++){
            for(int i=0 ; i<w ; i++){
                // 第j行,第i列,根据识别区域p1得到xy坐标的偏移,直接加就行
                double[] rgb = mat.get(j+py,i+px);
                // 缓存到 chw 中,mat 是 rgba 数据对应的下标 2103
                chw[index] = (float)(rgb[2]/255);//r
                chw[index + w * h * 1 ] = (float)(rgb[1]/255);//G
                chw[index + w * h * 2 ] = (float)(rgb[0]/255);//b
                index ++;
            }
        }

        // 创建张量并进行推理
        try {

            OnnxTensor tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{1,c,h,w});
            OrtSession.Result output = session.run(Collections.singletonMap("input", tensor));
            float[][] out = ((float[][][])(output.get(0)).getValue())[0];

            ArrayList<float[]> datas = new ArrayList<>();

            for(int i=0;i<out.length;i++){

                float[] data = out[i];

                float score1 = data[4]; // 边框置信度
                float score2 = data[15];// 人脸置信度
                if( score1 >= 0.2 && score2>= 0.2){
                    // xywh 转 x1y1x2y2
                    float xx = data[0];
                    float yy = data[1];
                    float ww = data[2];
                    float hh = data[3];
                    float[] xyxy = xywh2xyxy(new float[]{xx,yy,ww,hh},w,h);
                    data[0] = xyxy[0];
                    data[1] = xyxy[1];
                    data[2] = xyxy[2];
                    data[3] = xyxy[3];
                    datas.add(data);
                }
            }

            // nms
            ArrayList<float[]> datas_after_nms = new ArrayList<>();
            while (!datas.isEmpty()){
                float[] max = datas.get(0);
                datas_after_nms.add(max);
                Iterator<float[]> it = datas.iterator();
                while (it.hasNext()) {
                    // nsm阈值
                    float[] obj = it.next();
                    double iou = calculateIoU(max,obj);
                    if (iou > 0.5f) {
                        it.remove();
                    }
                }
            }

            // 标注
            datas_after_nms.stream().forEach(n->{

                // x y w h score  中心点坐标和分数
                // x y 关键点坐标
                // x y 关键点坐标
                // x y 关键点坐标
                // x y 关键点坐标
                // x y 关键点坐标
                // cls_conf 人脸置信度

                // 画边框和关键点需要添加偏移
                int x1 = Float.valueOf(n[0]).intValue() + px;
                int y1 = Float.valueOf(n[1]).intValue() + py;
                int x2 = Float.valueOf(n[2]).intValue() + px;
                int y2 = Float.valueOf(n[3]).intValue() + py;
                Imgproc.rectangle(mat, new Point(x1, y1), new Point(x2, y2), green, tickness);

                float point1_x = Float.valueOf(n[5]).intValue() + px;// 关键点1
                float point1_y = Float.valueOf(n[6]).intValue() + py;//
                float point2_x = Float.valueOf(n[7]).intValue() + px;// 关键点2
                float point2_y = Float.valueOf(n[8]).intValue() + py;//
                float point3_x = Float.valueOf(n[9]).intValue() + px;// 关键点3
                float point3_y = Float.valueOf(n[10]).intValue() + py;//
                float point4_x = Float.valueOf(n[11]).intValue() + px;// 关键点4
                float point4_y = Float.valueOf(n[12]).intValue() + py;//
                float point5_x = Float.valueOf(n[13]).intValue() + px;// 关键点5
                float point5_y = Float.valueOf(n[14]).intValue() + py;//

                Imgproc.circle(mat, new Point(point1_x, point1_y), 1, green, tickness);
                Imgproc.circle(mat, new Point(point2_x, point2_y), 1, green, tickness);
                Imgproc.circle(mat, new Point(point3_x, point3_y), 1, green, tickness);
                Imgproc.circle(mat, new Point(point4_x, point4_y), 1, green, tickness);
                Imgproc.circle(mat, new Point(point5_x, point5_y), 1, green, tickness);

            });

        }
        catch (Exception e){
            e.printStackTrace();
        }
    }


    // 中心点坐标转 xin xmax ymin ymax
    public static float[] xywh2xyxy(float[] bbox,float maxWidth,float maxHeight) {
        // 中心点坐标
        float x = bbox[0];
        float y = bbox[1];
        float w = bbox[2];
        float h = bbox[3];
        // 计算
        float x1 = x - w * 0.5f;
        float y1 = y - h * 0.5f;
        float x2 = x + w * 0.5f;
        float y2 = y + h * 0.5f;
        // 限制在图片区域内
        return new float[]{
                x1 < 0 ? 0 : x1,
                y1 < 0 ? 0 : y1,
                x2 > maxWidth ? maxWidth:x2,
                y2 > maxHeight? maxHeight:y2};
    }

    // 计算两个框的交并比
    private static double calculateIoU(float[] box1, float[] box2) {
        //  getXYXY() 返回 xmin-0 ymin-1 xmax-2 ymax-3
        double x1 = Math.max(box1[0], box2[0]);
        double y1 = Math.max(box1[1], box2[1]);
        double x2 = Math.min(box1[2], box2[2]);
        double y2 = Math.min(box1[3], box2[3]);
        double intersectionArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
        double box1Area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1);
        double box2Area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1);
        double unionArea = box1Area + box2Area - intersectionArea;
        return intersectionArea / unionArea;
    }

}




如果是使用 native 开发的话,cpp 推理代码如下:
 

//
// Created by tyf on 2023/11/7.
//

#include <benchmark.h>
#include "infer.h"
#include "opencv2/core/core.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include <ctime>
#include <onnxruntime_cxx_api.h>
#include <android/log.h>

#define LOG_TAG "com.tyf.demo"
#define LOG_INFO(msg) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "%s", msg)


// 文字和背景size、颜色
float fontScale = 0.35;
int tickness = 2;
cv::Scalar color1 = cv::Scalar(255,255,255);
cv::Scalar color2 = cv::Scalar(0,0,0);
cv::Scalar color3 = cv::Scalar(255,0,0);
cv::Scalar color4 = cv::Scalar(255,0,0);

// 绘制 fps
int draw_fps(cv::Mat& rgb, double t1,bool background){
    double t = t1;
    // resolve moving average
    float avg_fps = 0.f;
    {
        static double t0 = 0.f;
        static float fps_history[10] = {0.f};
        double t1 = ncnn::get_current_time();
        if (t0 == 0.f){
            t0 = t1;
            return 0;
        }
        float fps = 1000.f / (t1 - t0);
        t0 = t1;
        for (int i = 9; i >= 1; i--){
            fps_history[i] = fps_history[i - 1];
        }
        fps_history[0] = fps;
        if (fps_history[9] == 0.f){
            return 0;
        }
        for (int i = 0; i < 10; i++){
            avg_fps += fps_history[i];
        }
        avg_fps /= 10.f;
    }
    char text[32];
    sprintf(text, "time:%.2f fps=%.2f", t, avg_fps);
    int baseLine = 0;
    // 计算文本字符串宽度和高度
    cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, fontScale, tickness, &baseLine);
    // 设置到屏幕左上角
    int y = 0;
    int x = 0;
    // 设置到屏幕右上脚
//       int y = 0;
//    int x = rgb.cols - label_size.width;
    // 文字区域设置背景
    if(background){
        cv::rectangle(rgb, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),color2, -1);
    }
    // 设置文字
    cv::putText(rgb, text, cv::Point(x, y + label_size.height),cv::FONT_HERSHEY_SIMPLEX, fontScale, color1);
    return 0;
}


// 绘制文字(自动检测是否超出屏幕,超出的话向屏幕里面移动)
int draw_text(cv::Mat& rgb, const char text[], bool background, int x, int y) {
    int baseLine = 0;
    // 计算文本大小
    cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, fontScale, tickness, &baseLine);
    // 检查文字是否超出屏幕边界,如果超出则调整位置
    if (x + label_size.width > rgb.cols) {
        x = rgb.cols - label_size.width - 5;  // 超过右边界,调整到右边界
    }
    if (y + label_size.height > rgb.rows) {
        y = rgb.rows - label_size.height - 5;  // 超过下边界,调整到下边界
    }
    if (x < 0) {
        x = 0;  // 超过左边界,调整到左边界
    }
    if (y < 0) {
        y = 0;  // 超过上边界,调整到上边界
    }
    // 绘制背景框
    if (background) {
        cv::rectangle(rgb, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)), color2, -1);
    }
    // 绘制文字
    cv::putText(rgb, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, fontScale, color1);
    return 0;
}

// 绘制矩形边框并检测是否超出屏幕边界
void draw_rect1(cv::Mat& rgb, int x1, int y1, int x2, int y2) {
    std::string rect_str = "x1: " + std::to_string(x1) + ", y1: " + std::to_string(y1) +", x2: " + std::to_string(x2) + ", y2: " + std::to_string(y2);
    LOG_INFO(rect_str.c_str());
    // 获取图像尺寸
    int img_width = rgb.cols;
    int img_height = rgb.rows;
    // 检测矩形是否超出屏幕边界并调整位置
    if (x1 < 0) x1 = 0;               // 左边界
    if (y1 < 0) y1 = 0;               // 上边界
    if (x2 > img_width) x2 = img_width; // 右边界
    if (y2 > img_height) y2 = img_height; // 下边界
    // 确保 x1, y1 在 x2, y2 之前(矩形必须合法)
    if (x1 > x2) std::swap(x1, x2);
    if (y1 > y2) std::swap(y1, y2);
    // 绘制矩形边框
    cv::rectangle(rgb, cv::Rect(x1, y1, x2 - x1, y2 - y1), color3, tickness);
}

// 绘制矩形边框并检测是否超出屏幕边界
void draw_rect2(cv::Mat& rgb, int x, int y, int w, int h) {
    // 获取图像尺寸
    int img_width = rgb.cols;
    int img_height = rgb.rows;

    // 将 (x, y, w, h) 转换为 (x1, y1, x2, y2)
    int x1 = x;
    int y1 = y;
    int x2 = x + w;
    int y2 = y + h;

    // 检测矩形是否超出屏幕边界并调整位置
    if (x1 < 0) x1 = 0;                    // 左边界
    if (y1 < 0) y1 = 0;                    // 上边界
    if (x2 > img_width) x2 = img_width;    // 右边界
    if (y2 > img_height) y2 = img_height;  // 下边界
    // 调用 draw_rect1 绘制矩形
    draw_rect1(rgb, x1, y1, x2, y2);
}


// 绘制关键点
void draw_point(cv::Mat& rgb, int x, int y){
    int width = rgb.cols;
    int height = rgb.rows;
    if (x < 0 || x >= width || y < 0 || y >= height) {
        return; // 如果超出区域,不绘制
    }

    // 在指定位置绘制一个红色的圆点,圆点半径为 3,线条宽度为 -1,表示实心圆
    cv::circle(rgb, cv::Point(x, y), 3, color4, -1);
}



// 模型初始化
static Ort::Env env{ORT_LOGGING_LEVEL_ERROR, "yolov5"};
static Ort::SessionOptions sessionOptions = Ort::SessionOptions();
static Ort::AllocatorWithDefaultOptions allocator;
static Ort::Session* ort_session = nullptr;

void module_init(const char *filePath) {
    std::string logMessage = std::string("ORT version: ") + Ort::GetVersionString();
    LOG_INFO(logMessage.c_str());
    // 创建 ort 会话
    ort_session = new Ort::Session(env, filePath, sessionOptions);
    // 打印输入信息
    LOG_INFO("Input:");
    for (int i = 0; i < ort_session->GetInputCount(); i++) {
        std::string name = ort_session->GetInputNameAllocated(i, allocator).get();
        std::vector<int64_t> shape = ort_session->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
        std::string shapeStr = "[";
        for (size_t j = 0; j < shape.size(); j++) {
            shapeStr += std::to_string(shape[j]);
            if (j < shape.size() - 1) {
                shapeStr += ", ";
            }
        }
        shapeStr += "]";
        std::string line = name.append(" => ").append(shapeStr);
        LOG_INFO(line.c_str());
    }
    // 打印输出信息
    LOG_INFO("Output:");
    for (int i = 0; i < ort_session->GetOutputCount(); i++) {
        std::string name = ort_session->GetOutputNameAllocated(i, allocator).get();
        std::vector<int64_t> shape = ort_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
        std::string shapeStr = "[";
        for (size_t j = 0; j < shape.size(); j++) {
            shapeStr += std::to_string(shape[j]);
            if (j < shape.size() - 1) {
                shapeStr += ", ";
            }
        }
        shapeStr += "]";
        std::string line = name.append(" => ").append(shapeStr);
        LOG_INFO(line.c_str());
    }
}


// 计算两个框的 IOU
float compute_iou(float x1, float y1, float x2, float y2, float x1_prime, float y1_prime, float x2_prime, float y2_prime) {
    // 计算交集的左上角和右下角坐标
    float inter_x1 = std::max(x1, x1_prime);
    float inter_y1 = std::max(y1, y1_prime);
    float inter_x2 = std::min(x2, x2_prime);
    float inter_y2 = std::min(y2, y2_prime);
    // 交集宽高
    float inter_width = std::max(0.0f, inter_x2 - inter_x1);
    float inter_height = std::max(0.0f, inter_y2 - inter_y1);
    float inter_area = inter_width * inter_height;
    // 并集宽高
    float box1_area = (x2 - x1) * (y2 - y1);
    float box2_area = (x2_prime - x1_prime) * (y2_prime - y1_prime);
    float union_area = box1_area + box2_area - inter_area;
    // 计算 IOU
    return inter_area / union_area;
}


// NMS 操作
void nms(std::vector<std::vector<float>>& boxes, float iou_threshold) {
    // 按照置信度进行排序,从高到低
    std::sort(boxes.begin(), boxes.end(), [](const std::vector<float>& a, const std::vector<float>& b) {
        return a[0] > b[0];  // 按照置信度排序
    });
    std::vector<bool> keep(boxes.size(), true);
    for (size_t i = 0; i < boxes.size(); ++i) {
        if (!keep[i]) continue;
        // 计算每个框的 IOU,并删除那些与当前框的 IOU 大于阈值的框
        for (size_t j = i + 1; j < boxes.size(); ++j) {
            if (!keep[j]) continue;
            // 取框的坐标
            float x1 = boxes[i][2];
            float y1 = boxes[i][3];
            float x2 = boxes[i][4];
            float y2 = boxes[i][5];
            float x1_prime = boxes[j][2];
            float y1_prime = boxes[j][3];
            float x2_prime = boxes[j][4];
            float y2_prime = boxes[j][5];
            // 计算 IOU
            float iou = compute_iou(x1, y1, x2, y2, x1_prime, y1_prime, x2_prime, y2_prime);
            if (iou > iou_threshold) {
                keep[j] = false;  // 删除与当前框重叠度高的框
            }
        }
    }
    // 保留的框
    std::vector<std::vector<float>> filtered_boxes;
    for (size_t i = 0; i < boxes.size(); ++i) {
        if (keep[i]) {
            filtered_boxes.push_back(boxes[i]);
        }
    }
    // 更新为过滤后的框
    boxes = filtered_boxes;
}



// 传入一帧进行推理
// Input => [1, 3, 640, 640]
// Output => [1, 25200, 16]

// 定义输入和输出的名称
const char* input_names[] = {"input"};
const char* output_names[] = {"output"};

void infer_frame(cv::Mat& rgb) {

    cv::Mat mat;
    cv::resize(rgb, mat, cv::Size(640, 640));  // Resize
    mat.convertTo(mat, CV_32F, 1.0 / 255);    // [0, 1]
//    cv::cvtColor(mat, mat, cv::COLOR_BGR2RGB); // RGB

    // hwc 转为 chw,如果是 opencv3 以上 dnn 模块可以直接转换:cv::dnn::blobFromImage 否则只能手动转换
    // 手动转换就是将图片按照通道切开成3个单通道矩阵,最后合并到一起变成 rrr ggg bbb 的顺序
    std::vector<cv::Mat> bgrChannels(3);
    cv::split(mat, bgrChannels);

    // 每个通道数据的首地址
    auto* b_addr = reinterpret_cast<float*>(bgrChannels[0].data);
    auto* g_addr = reinterpret_cast<float*>(bgrChannels[1].data);
    auto* r_addr = reinterpret_cast<float*>(bgrChannels[2].data);

    // 最后合并,获取每个通道的 float* 指针,直接复制内存
    int pixels = mat.total();
    std::vector<float> input(pixels * 3);
    memcpy(&input[pixels*0], r_addr, pixels * sizeof(float));
    memcpy(&input[pixels*1], g_addr, pixels * sizeof(float));
    memcpy(&input[pixels*2], b_addr, pixels * sizeof(float));

    // 输入尺寸和名称
    std::array<int64_t, 4> in_shape{ 1, 3, 640, 640};
    // 输出尺寸和名称
    std::array<int64_t, 3> out_shape{ 1, 25200, 16};

    auto memory = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);

    // 创建输入张量
    Ort::Value in_tensor = Ort::Value::CreateTensor<float>(
            memory,
            input.data(),
            input.size(),
            in_shape.data(),
            in_shape.size());

    // 保存输出的张量
    std::vector<float> output(1 * 25200 * 16);
    Ort::Value out_tensor = Ort::Value::CreateTensor<float>(
            memory,
            output.data(),
            output.size(),
            out_shape.data(),
            out_shape.size());

    // 推理
    ort_session -> Run(
            Ort::RunOptions{nullptr},
            input_names,
            &in_tensor,
            1,
            output_names,
            &out_tensor,
            1);


    // 保存 nms 前的box
    std::vector<std::vector<float>> before_nms;
    for (int i = 0; i < 25200; ++i) {
        // x y w h score  中心点坐标和分数
        // x y 关键点坐标
        // x y 关键点坐标
        // x y 关键点坐标
        // x y 关键点坐标
        // x y 关键点坐标
        // cls_conf 人脸置信度
        const float* data = &output[i * 16];
        // 前四个是框的坐标 (x, y, w, h)
        float x = data[0];
        float y = data[1];
        float w = data[2];
        float h = data[3];
        float x1 = x - w * 0.5f;
        float y1 = y - h * 0.5f;
        float x2 = x + w * 0.5f;
        float y2 = y + h * 0.5f;
        float score1 = data[4]; // 边框置信度
        float point1x = data[5];
        float point1y = data[6];
        float point2x = data[7];
        float point2y = data[8];
        float point3x = data[9];
        float point3y = data[10];
        float point4x = data[11];
        float point4y = data[12];
        float point5x = data[13];
        float point5y = data[14];
        float score2 = data[15];// 人脸置信度
        if(score1 >= 0.5 && score2>= 0.5){
            std::vector<float> box = {
                    score1, score2,
                    x1, y1, x2, y2,
                    point1x,point1y,
                    point2x,point2y,
                    point3x,point3y,
                    point4x,point4y,
                    point5x,point5y
            };
            before_nms.push_back(box);
        }
    }

    // nms 操作
    nms(before_nms, 0.4);

    // 缩放比例,640*640缩放到原始宽高
    float width = static_cast<float>(rgb.cols) / 640.0f;
    float height = static_cast<float>(rgb.rows) / 640.0f;

    // 标注
    for (const auto& box : before_nms) {
        float x1 = box[2] * width;
        float y1 = box[3] * height;
        float x2 = box[4] * width;
        float y2 = box[5] * height;
        float point1x = box[6] * width;
        float point1y = box[7] * height;
        float point2x = box[8] * width;
        float point2y = box[9] * height;
        float point3x = box[10] * width;
        float point3y = box[11] * height;
        float point4x = box[12] * width;
        float point4y = box[13] * height;
        float point5x = box[14] * width;
        float point5y = box[15] * height;
        draw_rect1(rgb,x1,y1,x2,y2);
        draw_point(rgb,point1x,point1y);
        draw_point(rgb,point2x,point2y);
        draw_point(rgb,point3x,point3y);
        draw_point(rgb,point4x,point4y);
        draw_point(rgb,point5x,point5y);
    }
}



项目详细代码:

https://github.com/TangYuFan/deeplearn-mobile

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

0x13

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值