Below is an Android inference demo that uses opencv-camera to run real-time face detection on a defined region of each frame.
First, integrate opencv-camera into the project:
For convenience, import the entire opencv-android-sdk as a project module:
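Exactly how the SDK module is wired in depends on where you copied it; below is a minimal settings.gradle sketch, assuming the SDK's sdk directory was imported as a module named opencvsdk (the name referenced by the dependency below) — the commented path is only an example.
// settings.gradle — sketch only; assumes the OpenCV Android SDK's "sdk" directory was imported as a module named "opencvsdk"
include ':app'
include ':opencvsdk'
// If the SDK sits outside the project root, point the module at its directory explicitly, e.g.:
// project(':opencvsdk').projectDir = new File(settingsDir, '../OpenCV-android-sdk/sdk')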
Then add the OpenCV Java dependency to the existing app module; the key addition is the opencvsdk project dependency at the end of the block:
app/build.gradle
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar'])
implementation 'androidx.appcompat:appcompat:1.4.1'
implementation 'com.google.android.material:material:1.5.0'
implementation 'androidx.constraintlayout:constraintlayout:2.1.3'
testImplementation 'junit:junit:4.13.2'
androidTestImplementation 'androidx.test.ext:junit:1.1.3'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
implementation project(':opencvsdk')
}
Finally, load the OpenCV JNI library wherever OpenCV is used, for example in MainActivity:
System.loadLibrary("opencv_java4"); or OpenCVLoader.initDebug();
To use opencv-camera, have MainActivity extend CameraActivity and process each frame in the camera callback. For example, the listener below draws the detection-region border on every frame:
// Callback that receives every camera frame
private CameraBridgeViewBase.CvCameraViewListener2 cameraViewListener2 = new CameraBridgeViewBase.CvCameraViewListener2() {
@Override
public void onCameraViewStarted(int width, int height) {
System.out.println("开始预览 width="+width+",height="+height);
// 预览界面是 640*480,模型输入时 320*320,计算识别区域坐标
int detection_x1 = (640 - OnnxUtil.w)/2;
int detection_x2 = (640 - OnnxUtil.w)/2 + OnnxUtil.w;
int detection_y1 = (480 - OnnxUtil.h)/2;
int detection_y2 = (480 - OnnxUtil.h)/2 + OnnxUtil.h;
System.out.println("Detection region:"+"("+detection_x1+","+detection_y1+")"+"("+detection_x2+","+detection_y2+")");
// Cache the two corner points of the detection region
detection_p1 = new Point(detection_x1,detection_y1);
detection_p2 = new Point(detection_x2,detection_y2);
detection_box_color = new Scalar(255, 0, 0);
detection_box_tickness = 2;
}
@Override
public void onCameraViewStopped() {}
@Override
public Mat onCameraFrame(CameraBridgeViewBase.CvCameraViewFrame frame) {
// Get the frame as a cv::Mat
Mat mat = frame.rgba();
// Draw the detection-region border
Imgproc.rectangle(mat, detection_p1, detection_p2,detection_box_color,detection_box_tickness);
return mat;
}
};
Start the preview from the UI.
Layout resource:
<org.opencv.android.JavaCamera2View
android:id="@+id/camera_view"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintLeft_toLeftOf="parent"
android:layout_width="match_parent"
android:layout_height="match_parent">
</org.opencv.android.JavaCamera2View>
Java code that starts the preview:
private BaseLoaderCallback baseLoaderCallback = new BaseLoaderCallback(this) {
@Override
public void onManagerConnected(int status) {
switch (status) {
case LoaderCallbackInterface.SUCCESS: {
if (camera2View != null) {
// Select the front or back camera: 0 = back, 1 = front
camera2View.setCameraIndex(cameraId);
// Register the per-frame callback
camera2View.setCvCameraViewListener(cameraViewListener2);
// Show/hide the FPS meter: enableFpsMeter/disableFpsMeter
// To change the font or color, edit the FpsMeter class directly
camera2View.enableFpsMeter();
// Keep the view size equal to the model input to avoid resizing; model inputs are usually small, so the camera renders at a higher fps
camera2View.setMaxFrameSize(win_w,win_h);
// Start the preview
camera2View.enableView();
}
}
break;
default:
super.onManagerConnected(status);
break;
}
}
};
The complete MainActivity inference code:
package com.example.camera_opencv;
import android.content.pm.ActivityInfo;
import android.os.Bundle;
import android.view.WindowManager;
import com.example.camera_opencv.databinding.ActivityMainBinding;
import org.opencv.android.*;
import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.imgproc.Imgproc;
import java.util.Arrays;
import java.util.List;
public class MainActivity extends CameraActivity{
// Native libraries
static {
// Our own JNI library
System.loadLibrary("camera_opencv");
// The newly added OpenCV JNI library
System.loadLibrary("opencv_java4");
}
private ActivityMainBinding binding;
// Preview view
private JavaCamera2View camera2View;
// Camera index: 0 = back, 1 = front
private int cameraId = 1;
// Preview width/height; the detection region is constrained within this size
private int win_w = 640;
private int win_h = 480;
// Two corner points of the detection region
private Point detection_p1;
private Point detection_p2;
private Scalar detection_box_color;
private int detection_box_tickness;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
binding = ActivityMainBinding.inflate(getLayoutInflater());
setContentView(binding.getRoot());
// Load the model
OnnxUtil.loadModule(getAssets());
// Force landscape orientation
setRequestedOrientation(ActivityInfo.SCREEN_ORIENTATION_LANDSCAPE);
// Hide the status bar
getWindow().setFlags(WindowManager.LayoutParams.FLAG_FULLSCREEN, WindowManager.LayoutParams.FLAG_FULLSCREEN);
// Preview view
camera2View = findViewById(R.id.camera_view);
}
@Override
protected List<? extends CameraBridgeViewBase> getCameraViewList() {
return Arrays.asList(camera2View);
}
@Override
public void onPause() {
super.onPause();
if (camera2View != null) {
// Stop the preview
camera2View.disableView();
}
}
@Override
public void onResume() {
super.onResume();
if (OpenCVLoader.initDebug()) {
baseLoaderCallback.onManagerConnected(LoaderCallbackInterface.SUCCESS);
} else {
OpenCVLoader.initAsync(OpenCVLoader.OPENCV_VERSION, this, baseLoaderCallback);
}
}
// Callback that receives every camera frame
private CameraBridgeViewBase.CvCameraViewListener2 cameraViewListener2 = new CameraBridgeViewBase.CvCameraViewListener2() {
@Override
public void onCameraViewStarted(int width, int height) {
System.out.println("开始预览 width="+width+",height="+height);
// 预览界面是 640*480,模型输入时 320*320,计算识别区域坐标
int detection_x1 = (640 - OnnxUtil.w)/2;
int detection_x2 = (640 - OnnxUtil.w)/2 + OnnxUtil.w;
int detection_y1 = (480 - OnnxUtil.h)/2;
int detection_y2 = (480 - OnnxUtil.h)/2 + OnnxUtil.h;
System.out.println("Detection region:"+"("+detection_x1+","+detection_y1+")"+"("+detection_x2+","+detection_y2+")");
// Cache the two corner points of the detection region
detection_p1 = new Point(detection_x1,detection_y1);
detection_p2 = new Point(detection_x2,detection_y2);
detection_box_color = new Scalar(255, 0, 0);
detection_box_tickness = 2;
}
@Override
public void onCameraViewStopped() {}
@Override
public Mat onCameraFrame(CameraBridgeViewBase.CvCameraViewFrame frame) {
// Get the frame as a cv::Mat
Mat mat = frame.rgba();
// Draw the detection-region border
Imgproc.rectangle(mat, detection_p1, detection_p2,detection_box_color,detection_box_tickness);
// Run inference and draw the results
OnnxUtil.inference(mat,detection_p1,detection_p2);
return mat;
}
};
// Start the preview
private BaseLoaderCallback baseLoaderCallback = new BaseLoaderCallback(this) {
@Override
public void onManagerConnected(int status) {
switch (status) {
case LoaderCallbackInterface.SUCCESS: {
if (camera2View != null) {
// Select the front or back camera: 0 = back, 1 = front
camera2View.setCameraIndex(cameraId);
// Register the per-frame callback
camera2View.setCvCameraViewListener(cameraViewListener2);
// Show/hide the FPS meter: enableFpsMeter/disableFpsMeter
// To change the font or color, edit the FpsMeter class directly
camera2View.enableFpsMeter();
// Keep the view size equal to the model input to avoid resizing; model inputs are usually small, so the camera renders at a higher fps
camera2View.setMaxFrameSize(win_w,win_h);
// Start the preview
camera2View.enableView();
}
}
break;
default:
super.onManagerConnected(status);
break;
}
}
};
}
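One build note: because the activity inflates ActivityMainBinding, view binding must be enabled in the app module's build.gradle; a minimal sketch, assuming a Gradle plugin version that supports the buildFeatures block:
android {
    buildFeatures {
        // generates ActivityMainBinding from activity_main.xml
        viewBinding true
    }
}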
ONNX model loading and inference code.
It uses Microsoft's ONNX Runtime:
implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'
package com.example.camera_opencv;
import ai.onnxruntime.*;
import android.content.res.AssetManager;
import org.opencv.core.*;
import org.opencv.dnn.Dnn;
import org.opencv.imgproc.Imgproc;
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.*;
public class OnnxUtil {
// onnxruntime environment
public static OrtEnvironment env;
public static OrtSession session;
// Model input dimensions
public static int w = 0;
public static int h = 0;
public static int c = 3;
// Annotation color
public static Scalar green = new Scalar(0, 255, 0);
public static int tickness = 2;
// Load the model
public static void loadModule(AssetManager assetManager){
// Several models are available:
// yolov5face-blazeface-640x640.onnx 3.4Mb
// yolov5face-l-640x640.onnx 181Mb
// yolov5face-m-640x640.onnx 83Mb
// yolov5face-n-0.5-320x320.onnx 2.5Mb
// yolov5face-n-0.5-640x640.onnx 4.6Mb
// yolov5face-n-640x640.onnx 9.5Mb
// yolov5face-s-640x640.onnx 30Mb
w = 320;
h = 320;
c = 3;
try {
// Model input: input -> [1, 3, 320, 320] -> FLOAT
// Model output: output -> [1, 6300, 16] -> FLOAT
InputStream inputStream = assetManager.open("yolov5face-n-0.5-320x320.onnx");
ByteArrayOutputStream buffer = new ByteArrayOutputStream();
int nRead;
byte[] data = new byte[1024];
while ((nRead = inputStream.read(data, 0, data.length)) != -1) {
buffer.write(data, 0, nRead);
}
buffer.flush();
byte[] module = buffer.toByteArray();
System.out.println("开始加载模型");
env = OrtEnvironment.getEnvironment();
session = env.createSession(module, new OrtSession.SessionOptions());
session.getInputInfo().entrySet().stream().forEach(n -> {
String inputName = n.getKey();
NodeInfo inputInfo = n.getValue();
long[] shape = ((TensorInfo) inputInfo.getInfo()).getShape();
String javaType = ((TensorInfo) inputInfo.getInfo()).type.toString();
System.out.println("模型输入: "+inputName + " -> " + Arrays.toString(shape) + " -> " + javaType);
});
session.getOutputInfo().entrySet().stream().forEach(n -> {
String outputName = n.getKey();
NodeInfo outputInfo = n.getValue();
long[] shape = ((TensorInfo) outputInfo.getInfo()).getShape();
String javaType = ((TensorInfo) outputInfo.getInfo()).type.toString();
System.out.println("模型输出: "+outputName + " -> " + Arrays.toString(shape) + " -> " + javaType);
});
} catch (Exception e) {
e.printStackTrace();
}
}
// Run inference given the original frame and the two corner points of the detection region
public static void inference(Mat mat,Point detection_p1,Point detection_p2){
int px = Double.valueOf(detection_p1.x).intValue();
int py = Double.valueOf(detection_p1.y).intValue();
// Extract RGB into CHW layout and normalize, i.e. rrr... ggg... bbb...
float[] chw = new float[c*h*w];
// Pixel index
int index = 0;
for(int j=0 ; j<h ; j++){
for(int i=0 ; i<w ; i++){
// Row j, column i; offset the x/y coordinates by detection point p1
double[] rgb = mat.get(j+py,i+px);
// Store into chw; the mat holds RGBA data, so channel indices 2,1,0 are used
chw[index] = (float)(rgb[2]/255);//r
chw[index + w * h * 1 ] = (float)(rgb[1]/255);//G
chw[index + w * h * 2 ] = (float)(rgb[0]/255);//b
index ++;
}
}
// Create the input tensor and run inference
try {
OnnxTensor tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{1,c,h,w});
OrtSession.Result output = session.run(Collections.singletonMap("input", tensor));
float[][] out = ((float[][][])(output.get(0)).getValue())[0];
ArrayList<float[]> datas = new ArrayList<>();
for(int i=0;i<out.length;i++){
float[] data = out[i];
float score1 = data[4]; // box confidence
float score2 = data[15];// face confidence
if( score1 >= 0.2 && score2>= 0.2){
// convert xywh to x1y1x2y2
float xx = data[0];
float yy = data[1];
float ww = data[2];
float hh = data[3];
float[] xyxy = xywh2xyxy(new float[]{xx,yy,ww,hh},w,h);
data[0] = xyxy[0];
data[1] = xyxy[1];
data[2] = xyxy[2];
data[3] = xyxy[3];
datas.add(data);
}
}
// NMS (greedy; see the score-sorting note after this class)
ArrayList<float[]> datas_after_nms = new ArrayList<>();
while (!datas.isEmpty()){
float[] max = datas.get(0);
datas_after_nms.add(max);
Iterator<float[]> it = datas.iterator();
while (it.hasNext()) {
// NMS threshold
float[] obj = it.next();
double iou = calculateIoU(max,obj);
if (iou > 0.5f) {
it.remove();
}
}
}
// Draw annotations
datas_after_nms.stream().forEach(n->{
// x y w h score: center coordinates, size and box score
// x y landmark coordinates
// x y landmark coordinates
// x y landmark coordinates
// x y landmark coordinates
// x y landmark coordinates
// cls_conf: face confidence
// Boxes and landmarks must be offset back into the full frame
int x1 = Float.valueOf(n[0]).intValue() + px;
int y1 = Float.valueOf(n[1]).intValue() + py;
int x2 = Float.valueOf(n[2]).intValue() + px;
int y2 = Float.valueOf(n[3]).intValue() + py;
Imgproc.rectangle(mat, new Point(x1, y1), new Point(x2, y2), green, tickness);
float point1_x = Float.valueOf(n[5]).intValue() + px;// landmark 1
float point1_y = Float.valueOf(n[6]).intValue() + py;//
float point2_x = Float.valueOf(n[7]).intValue() + px;// landmark 2
float point2_y = Float.valueOf(n[8]).intValue() + py;//
float point3_x = Float.valueOf(n[9]).intValue() + px;// landmark 3
float point3_y = Float.valueOf(n[10]).intValue() + py;//
float point4_x = Float.valueOf(n[11]).intValue() + px;// landmark 4
float point4_y = Float.valueOf(n[12]).intValue() + py;//
float point5_x = Float.valueOf(n[13]).intValue() + px;// landmark 5
float point5_y = Float.valueOf(n[14]).intValue() + py;//
Imgproc.circle(mat, new Point(point1_x, point1_y), 1, green, tickness);
Imgproc.circle(mat, new Point(point2_x, point2_y), 1, green, tickness);
Imgproc.circle(mat, new Point(point3_x, point3_y), 1, green, tickness);
Imgproc.circle(mat, new Point(point4_x, point4_y), 1, green, tickness);
Imgproc.circle(mat, new Point(point5_x, point5_y), 1, green, tickness);
});
}
catch (Exception e){
e.printStackTrace();
}
}
// Convert center-based xywh to xmin ymin xmax ymax
public static float[] xywh2xyxy(float[] bbox,float maxWidth,float maxHeight) {
// Center coordinates and size
float x = bbox[0];
float y = bbox[1];
float w = bbox[2];
float h = bbox[3];
// Compute the corners
float x1 = x - w * 0.5f;
float y1 = y - h * 0.5f;
float x2 = x + w * 0.5f;
float y2 = y + h * 0.5f;
// Clamp to the image bounds
return new float[]{
x1 < 0 ? 0 : x1,
y1 < 0 ? 0 : y1,
x2 > maxWidth ? maxWidth:x2,
y2 > maxHeight? maxHeight:y2};
}
// Compute the IoU of two boxes
private static double calculateIoU(float[] box1, float[] box2) {
// Box layout: xmin=0, ymin=1, xmax=2, ymax=3
double x1 = Math.max(box1[0], box2[0]);
double y1 = Math.max(box1[1], box2[1]);
double x2 = Math.min(box1[2], box2[2]);
double y2 = Math.min(box1[3], box2[3]);
double intersectionArea = Math.max(0, x2 - x1 + 1) * Math.max(0, y2 - y1 + 1);
double box1Area = (box1[2] - box1[0] + 1) * (box1[3] - box1[1] + 1);
double box2Area = (box2[2] - box2[0] + 1) * (box2[3] - box2[1] + 1);
double unionArea = box1Area + box2Area - intersectionArea;
return intersectionArea / unionArea;
}
}
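A note on the NMS loop above: it treats datas.get(0) as the highest-scoring remaining box, which is only correct if the candidates are ordered by score, and the raw model output is not sorted. A small helper sketch (hypothetical NmsUtil class; it uses data[4] as the box confidence and data[15] as the face confidence, matching the output layout above) that could be applied to datas right before the while loop:
import java.util.Collections;
import java.util.List;

// Sketch: sort detection candidates by descending combined score so that the
// greedy NMS loop always starts from the best remaining box.
public final class NmsUtil {
    private NmsUtil() {}

    public static void sortByScoreDesc(List<float[]> datas) {
        // data[4] = box confidence, data[15] = face confidence
        Collections.sort(datas, (a, b) -> Float.compare(b[4] * b[15], a[4] * a[15]));
    }
}
Calling NmsUtil.sortByScoreDesc(datas) before the while (!datas.isEmpty()) loop keeps the rest of the code unchanged.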
If you are developing with the NDK instead, the C++ inference code is as follows:
//
// Created by tyf on 2023/11/7.
//
#include <benchmark.h>
#include "infer.h"
#include "opencv2/core/core.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include <ctime>
#include <onnxruntime_cxx_api.h>
#include <android/log.h>
#define LOG_TAG "com.tyf.demo"
#define LOG_INFO(msg) __android_log_print(ANDROID_LOG_INFO, LOG_TAG, "%s", msg)
// Text and background size and colors
float fontScale = 0.35;
int tickness = 2;
cv::Scalar color1 = cv::Scalar(255,255,255);
cv::Scalar color2 = cv::Scalar(0,0,0);
cv::Scalar color3 = cv::Scalar(255,0,0);
cv::Scalar color4 = cv::Scalar(255,0,0);
// Draw the fps overlay
int draw_fps(cv::Mat& rgb, double t1,bool background){
double t = t1;
// resolve moving average
float avg_fps = 0.f;
{
static double t0 = 0.f;
static float fps_history[10] = {0.f};
double t1 = ncnn::get_current_time();
if (t0 == 0.f){
t0 = t1;
return 0;
}
float fps = 1000.f / (t1 - t0);
t0 = t1;
for (int i = 9; i >= 1; i--){
fps_history[i] = fps_history[i - 1];
}
fps_history[0] = fps;
if (fps_history[9] == 0.f){
return 0;
}
for (int i = 0; i < 10; i++){
avg_fps += fps_history[i];
}
avg_fps /= 10.f;
}
char text[32];
sprintf(text, "time:%.2f fps=%.2f", t, avg_fps);
int baseLine = 0;
// Compute the width and height of the text string
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, fontScale, tickness, &baseLine);
// Place at the top-left corner of the screen
int y = 0;
int x = 0;
// Place at the top-right corner of the screen
// int y = 0;
// int x = rgb.cols - label_size.width;
// Draw a background behind the text
if(background){
cv::rectangle(rgb, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)),color2, -1);
}
// Draw the text
cv::putText(rgb, text, cv::Point(x, y + label_size.height),cv::FONT_HERSHEY_SIMPLEX, fontScale, color1);
return 0;
}
// Draw text (automatically shifted back on-screen if it would overflow)
int draw_text(cv::Mat& rgb, const char text[], bool background, int x, int y) {
int baseLine = 0;
// Compute the text size
cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, fontScale, tickness, &baseLine);
// Check whether the text exceeds the screen bounds and adjust its position if it does
if (x + label_size.width > rgb.cols) {
x = rgb.cols - label_size.width - 5; // past the right edge, pull back inside
}
if (y + label_size.height > rgb.rows) {
y = rgb.rows - label_size.height - 5; // past the bottom edge, pull back inside
}
if (x < 0) {
x = 0; // past the left edge, clamp to it
}
if (y < 0) {
y = 0; // past the top edge, clamp to it
}
// Draw the background box
if (background) {
cv::rectangle(rgb, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)), color2, -1);
}
// Draw the text
cv::putText(rgb, text, cv::Point(x, y + label_size.height), cv::FONT_HERSHEY_SIMPLEX, fontScale, color1);
return 0;
}
// Draw a rectangle, clamping it to the screen bounds
void draw_rect1(cv::Mat& rgb, int x1, int y1, int x2, int y2) {
std::string rect_str = "x1: " + std::to_string(x1) + ", y1: " + std::to_string(y1) +", x2: " + std::to_string(x2) + ", y2: " + std::to_string(y2);
LOG_INFO(rect_str.c_str());
// Image size
int img_width = rgb.cols;
int img_height = rgb.rows;
// Clamp the rectangle to the image bounds
if (x1 < 0) x1 = 0; // left edge
if (y1 < 0) y1 = 0; // top edge
if (x2 > img_width) x2 = img_width; // right edge
if (y2 > img_height) y2 = img_height; // bottom edge
// Ensure x1,y1 precede x2,y2 (the rectangle must be valid)
if (x1 > x2) std::swap(x1, x2);
if (y1 > y2) std::swap(y1, y2);
// Draw the rectangle
cv::rectangle(rgb, cv::Rect(x1, y1, x2 - x1, y2 - y1), color3, tickness);
}
// Draw a rectangle given (x, y, w, h), clamping it to the screen bounds
void draw_rect2(cv::Mat& rgb, int x, int y, int w, int h) {
// Image size
int img_width = rgb.cols;
int img_height = rgb.rows;
// Convert (x, y, w, h) to (x1, y1, x2, y2)
int x1 = x;
int y1 = y;
int x2 = x + w;
int y2 = y + h;
// Clamp the rectangle to the image bounds
if (x1 < 0) x1 = 0; // left edge
if (y1 < 0) y1 = 0; // top edge
if (x2 > img_width) x2 = img_width; // right edge
if (y2 > img_height) y2 = img_height; // bottom edge
// Delegate to draw_rect1 to draw the rectangle
draw_rect1(rgb, x1, y1, x2, y2);
}
// Draw a landmark point
void draw_point(cv::Mat& rgb, int x, int y){
int width = rgb.cols;
int height = rgb.rows;
if (x < 0 || x >= width || y < 0 || y >= height) {
return; // skip points outside the image
}
// Draw a filled circle of radius 3 at the given position (thickness -1 means filled)
cv::circle(rgb, cv::Point(x, y), 3, color4, -1);
}
// Model initialization
static Ort::Env env{ORT_LOGGING_LEVEL_ERROR, "yolov5"};
static Ort::SessionOptions sessionOptions = Ort::SessionOptions();
static Ort::AllocatorWithDefaultOptions allocator;
static Ort::Session* ort_session = nullptr;
void module_init(const char *filePath) {
std::string logMessage = std::string("ORT version: ") + Ort::GetVersionString();
LOG_INFO(logMessage.c_str());
// Create the ORT session
ort_session = new Ort::Session(env, filePath, sessionOptions);
// Print input info
LOG_INFO("Input:");
for (int i = 0; i < ort_session->GetInputCount(); i++) {
std::string name = ort_session->GetInputNameAllocated(i, allocator).get();
std::vector<int64_t> shape = ort_session->GetInputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
std::string shapeStr = "[";
for (size_t j = 0; j < shape.size(); j++) {
shapeStr += std::to_string(shape[j]);
if (j < shape.size() - 1) {
shapeStr += ", ";
}
}
shapeStr += "]";
std::string line = name.append(" => ").append(shapeStr);
LOG_INFO(line.c_str());
}
// Print output info
LOG_INFO("Output:");
for (int i = 0; i < ort_session->GetOutputCount(); i++) {
std::string name = ort_session->GetOutputNameAllocated(i, allocator).get();
std::vector<int64_t> shape = ort_session->GetOutputTypeInfo(i).GetTensorTypeAndShapeInfo().GetShape();
std::string shapeStr = "[";
for (size_t j = 0; j < shape.size(); j++) {
shapeStr += std::to_string(shape[j]);
if (j < shape.size() - 1) {
shapeStr += ", ";
}
}
shapeStr += "]";
std::string line = name.append(" => ").append(shapeStr);
LOG_INFO(line.c_str());
}
}
// Compute the IoU of two boxes
float compute_iou(float x1, float y1, float x2, float y2, float x1_prime, float y1_prime, float x2_prime, float y2_prime) {
// Top-left and bottom-right corners of the intersection
float inter_x1 = std::max(x1, x1_prime);
float inter_y1 = std::max(y1, y1_prime);
float inter_x2 = std::min(x2, x2_prime);
float inter_y2 = std::min(y2, y2_prime);
// Intersection width, height and area
float inter_width = std::max(0.0f, inter_x2 - inter_x1);
float inter_height = std::max(0.0f, inter_y2 - inter_y1);
float inter_area = inter_width * inter_height;
// Union area
float box1_area = (x2 - x1) * (y2 - y1);
float box2_area = (x2_prime - x1_prime) * (y2_prime - y1_prime);
float union_area = box1_area + box2_area - inter_area;
// Compute the IoU
return inter_area / union_area;
}
// NMS
void nms(std::vector<std::vector<float>>& boxes, float iou_threshold) {
// Sort by confidence, from high to low
std::sort(boxes.begin(), boxes.end(), [](const std::vector<float>& a, const std::vector<float>& b) {
return a[0] > b[0]; // sort by confidence
});
std::vector<bool> keep(boxes.size(), true);
for (size_t i = 0; i < boxes.size(); ++i) {
if (!keep[i]) continue;
// Drop boxes whose IoU with the current box exceeds the threshold
for (size_t j = i + 1; j < boxes.size(); ++j) {
if (!keep[j]) continue;
// Box coordinates
float x1 = boxes[i][2];
float y1 = boxes[i][3];
float x2 = boxes[i][4];
float y2 = boxes[i][5];
float x1_prime = boxes[j][2];
float y1_prime = boxes[j][3];
float x2_prime = boxes[j][4];
float y2_prime = boxes[j][5];
// Compute the IoU
float iou = compute_iou(x1, y1, x2, y2, x1_prime, y1_prime, x2_prime, y2_prime);
if (iou > iou_threshold) {
keep[j] = false; // drop boxes that heavily overlap the current one
}
}
}
// Surviving boxes
std::vector<std::vector<float>> filtered_boxes;
for (size_t i = 0; i < boxes.size(); ++i) {
if (keep[i]) {
filtered_boxes.push_back(boxes[i]);
}
}
// Replace with the filtered boxes
boxes = filtered_boxes;
}
// Run inference on one frame
// Input => [1, 3, 640, 640]
// Output => [1, 25200, 16]
// Input and output tensor names
const char* input_names[] = {"input"};
const char* output_names[] = {"output"};
void infer_frame(cv::Mat& rgb) {
cv::Mat mat;
cv::resize(rgb, mat, cv::Size(640, 640)); // Resize
mat.convertTo(mat, CV_32F, 1.0 / 255); // [0, 1]
// cv::cvtColor(mat, mat, cv::COLOR_BGR2RGB); // RGB
// Convert HWC to CHW; with the OpenCV 3+ dnn module this can be done directly with cv::dnn::blobFromImage (see the sketch after this listing), otherwise it has to be done manually
// The manual conversion splits the image into 3 single-channel mats and concatenates them into rrr ggg bbb order
std::vector<cv::Mat> bgrChannels(3);
cv::split(mat, bgrChannels);
// Start address of each channel's data
auto* b_addr = reinterpret_cast<float*>(bgrChannels[0].data);
auto* g_addr = reinterpret_cast<float*>(bgrChannels[1].data);
auto* r_addr = reinterpret_cast<float*>(bgrChannels[2].data);
// Merge: copy each channel's float data into one contiguous buffer
int pixels = mat.total();
std::vector<float> input(pixels * 3);
memcpy(&input[pixels*0], r_addr, pixels * sizeof(float));
memcpy(&input[pixels*1], g_addr, pixels * sizeof(float));
memcpy(&input[pixels*2], b_addr, pixels * sizeof(float));
// Input shape
std::array<int64_t, 4> in_shape{ 1, 3, 640, 640};
// Output shape
std::array<int64_t, 3> out_shape{ 1, 25200, 16};
auto memory = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
// Create the input tensor
Ort::Value in_tensor = Ort::Value::CreateTensor<float>(
memory,
input.data(),
input.size(),
in_shape.data(),
in_shape.size());
// Buffer for the output tensor
std::vector<float> output(1 * 25200 * 16);
Ort::Value out_tensor = Ort::Value::CreateTensor<float>(
memory,
output.data(),
output.size(),
out_shape.data(),
out_shape.size());
// Run inference
ort_session -> Run(
Ort::RunOptions{nullptr},
input_names,
&in_tensor,
1,
output_names,
&out_tensor,
1);
// Boxes before NMS
std::vector<std::vector<float>> before_nms;
for (int i = 0; i < 25200; ++i) {
// x y w h score: center coordinates, size and box score
// x y landmark coordinates
// x y landmark coordinates
// x y landmark coordinates
// x y landmark coordinates
// x y landmark coordinates
// cls_conf: face confidence
const float* data = &output[i * 16];
// The first four values are the box (x, y, w, h)
float x = data[0];
float y = data[1];
float w = data[2];
float h = data[3];
float x1 = x - w * 0.5f;
float y1 = y - h * 0.5f;
float x2 = x + w * 0.5f;
float y2 = y + h * 0.5f;
float score1 = data[4]; // box confidence
float point1x = data[5];
float point1y = data[6];
float point2x = data[7];
float point2y = data[8];
float point3x = data[9];
float point3y = data[10];
float point4x = data[11];
float point4y = data[12];
float point5x = data[13];
float point5y = data[14];
float score2 = data[15];// face confidence
if(score1 >= 0.5 && score2>= 0.5){
std::vector<float> box = {
score1, score2,
x1, y1, x2, y2,
point1x,point1y,
point2x,point2y,
point3x,point3y,
point4x,point4y,
point5x,point5y
};
before_nms.push_back(box);
}
}
// NMS
nms(before_nms, 0.4);
// Scale factors mapping 640*640 back to the original frame size
float width = static_cast<float>(rgb.cols) / 640.0f;
float height = static_cast<float>(rgb.rows) / 640.0f;
// Draw annotations
for (const auto& box : before_nms) {
float x1 = box[2] * width;
float y1 = box[3] * height;
float x2 = box[4] * width;
float y2 = box[5] * height;
float point1x = box[6] * width;
float point1y = box[7] * height;
float point2x = box[8] * width;
float point2y = box[9] * height;
float point3x = box[10] * width;
float point3y = box[11] * height;
float point4x = box[12] * width;
float point4y = box[13] * height;
float point5x = box[14] * width;
float point5y = box[15] * height;
draw_rect1(rgb,x1,y1,x2,y2);
draw_point(rgb,point1x,point1y);
draw_point(rgb,point2x,point2y);
draw_point(rgb,point3x,point3y);
draw_point(rgb,point4x,point4y);
draw_point(rgb,point5x,point5y);
}
}
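As mentioned in the comment inside infer_frame, the manual HWC-to-CHW conversion can be replaced by cv::dnn::blobFromImage when OpenCV's dnn module is available. A minimal sketch of a hypothetical helper (the swapRB flag is an assumption — set it to match the channel order your model expects):
#include <opencv2/dnn/dnn.hpp>
#include <vector>

// Sketch: resize, scale to [0,1] and convert HWC to a contiguous NCHW float blob in one call
static void hwc_to_chw_with_blob(const cv::Mat& frame, std::vector<float>& input) {
    cv::Mat blob = cv::dnn::blobFromImage(frame, 1.0 / 255.0, cv::Size(640, 640),
                                          cv::Scalar(), /*swapRB=*/true, /*crop=*/false);
    // blob is 1x3x640x640 and continuous, so it can be copied straight into the input buffer
    input.assign((const float*)blob.data, (const float*)blob.data + blob.total());
}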
Full project code: