一、思路
1.获取模型文件,获取他所需input参数,本文使用的yoloV8提供的demo模型,为(1,3,640,640),80个标签
2.从接口获取图片,转成模型所需input参数
3.把图片数据放入模型中运算,得到结果集
4.对结果集进行过滤和非极大值抑制
5.根据结果集对原图画框打标签,返回图片
二、获取模型,构建模型环境
需要添加依赖
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.16.0</version>
</dependency>
具体代码:
public class OnnxExecutor {
static{
// 加载opencv需要的动态库
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
}
/**
* 环境
*/
public OrtEnvironment env;
/**
* session
*/
public OrtSession session;
/**
* 模型参数名称
*/
public JSONObject names;
/**
*
*/
public long count;
/**
*
*/
public long channels;
/**
* 模型高(640)
*/
public long netHeight;
/**
* 模型宽(640)
*/
public long netWidth;
/**
* 图片高
*/
public float srcw;
/**
* 图片宽
*/
public float srch;
/**
* 图像原数据
*/
public Mat src;
/**
* 标注颜色
*/
public static Scalar color = new Scalar(0, 0, 255);
public static int tickness = 2;
/**
* 检测框筛选阈值,参考 detect.py 中的设置
*/
public float confThreshold = 0.75f;
/**
* 交集阈值
*/
public float nmsThreshold = 0.5f;
public OnnxExecutor(String modelPath, Float nmsThreshold) {
this(modelPath);
if (Objects.nonNull(nmsThreshold)) {
this.nmsThreshold = nmsThreshold;
}
}
// modelPath我使用的是绝对地址,System.getProperty("user.dir") + "/src/main/resources/static/best.onnx";
public OnnxExecutor(String modelPath) {
try (OrtEnvironment env = OrtEnvironment.getEnvironment()) {
this.env = env;
// 初始化session
session = env.createSession(modelPath, new OrtSession.SessionOptions());
OnnxModelMetadata metadata = session.getMetadata();
Map<String, NodeInfo> infoMap = session.getInputInfo();
TensorInfo nodeInfo = (TensorInfo)infoMap.get("images").getInfo();
String nameClass = metadata.getCustomMetadata().get("names");
// 标签数据
names = JSONObject.parseObject(nameClass.replace("\"","\"\""));
//1 模型每次处理一张图片
count = nodeInfo.getShape()[0];
//3 模型通道数
channels = nodeInfo.getShape()[1];
//640 模型高
netHeight = nodeInfo.getShape()[2];
//640 模型宽
netWidth = nodeInfo.getShape()[3];
System.out.println(names.get(0));
} catch (Exception e) {
log.error(e.getMessage());
throw new RuntimeException();
} finally {
}
}
}
三、把图片转换成对应参数格式
图片处理需要依赖OpenCv,我使用的4.8.0版本,从官网下载的jar包和dll。依稀记得4.5.0可以直接使用maven下载,没有试过
<dependency>
<groupId>org.opencv</groupId>
<artifactId>opencv</artifactId>
<version>4.8.0</version><scope>system</scope>
<systemPath>${project.basedir}/src/main/resources/lib/opencv/opencv-480.jar</systemPath>
</dependency>
加载动态库
static{
// 加载opencv需要的动态库
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
}
这个里有个坑,在使用动态库的时候,不能添加spring-boot-devtools依赖,会导致在图像转换的时候报找不到本地函数n_convertTo,具体原因不是很清楚。
图片转换及使用模型获取结果集
public String identifyImg(MultipartFile input){
//Image图片
try {
// 获取图片类型(PNG/JPRG)
Path path = Paths.get(input.getOriginalFilename());
// jpeg jpg bmp gif wbmp
String imgType = Files.probeContentType(path);
String[] split = imgType.split("\\/");
if (split.length > 1) {
imgType = split[1];
}
// 1.图片格式转换
BufferedImage image = ImageIO.read(input.getInputStream());
// new File();
srcw = image.getWidth();
srch = image.getHeight();
// 将BufferedImage转换为字节数组
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ImageIO.write(image, imgType, baos);
// mat:矩阵
src = new Mat(image.getHeight(), image.getWidth(), CvType.CV_8UC3);
src.put(0, 0, ((DataBufferByte) image.getRaster().getDataBuffer()).getData());
// 需处理图像数据
Mat dst = new Mat();
Imgproc.resize(src, dst, new Size(netWidth, netHeight));
// 2.图片转换为模型所需参数
Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
// 归一化 0-255 转 0-1
dst.convertTo(dst, CvType.CV_32FC1, 1. / 255);
float[] whc = new float[ Long.valueOf(channels).intValue() * Long.valueOf(netWidth).intValue() * Long.valueOf(netHeight).intValue() ];
dst.get(0, 0, whc);
float[] chw = new float[src.length];
int j = 0;
for (int ch = 0; ch < 3; ++ch) {
for (int i = ch; i < src.length; i += 3) {
chw[j] = src[i];
j++;
}
}
OnnxTensor tensor = null;
try {
tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{count,channels,netWidth,netHeight});
}
catch (Exception e){
e.printStackTrace();
System.exit(0);
}
// 执行
OrtSession.Result result = session.run(Collections.singletonMap("images", tensor));
System.out.println("res Data: "+result.get(0));
// 结果处理
OnnxTensor res = (OnnxTensor)result.get(0);
// [图片数量][84 = x+y+w+h+80个数据集类别][8400]
float[][][] dataRes = (float[][][])res.getValue();
// 获取第一张图片的结果
float[][] data = dataRes[0];
} catch (Exception e) {
throw new RuntimeException(e);
} finally {
}
return "ok";
}
四、对结果集进行处理
首先对于模型返回的数据需要有个大体的认识,模型的返回值是由模型本身决定的,例子中他的输出结构为[图片数量][坐标(x,y,w,h)加标签匹配度][8400(固定值,模型各尺度输出特征图叠加之后的结果)]
private Map<Object,Object> method2(float[][] data) {
// 将矩阵转置
// 先将xywh部分转置
float rawData[][]=new float[data[0].length][6];
System.out.println(data.length-1);
for(int i=0;i<4;i++){
for(int j=0;j<data[0].length;j++){
rawData[j][i]=data[i][j];
}
}
// 保存每个检查框置信值最高的类型置信值和该类型下标
for(int i=0;i<data[0].length;i++){
for(int j=4;j<data.length;j++){
if(rawData[i][4]<data[j][i]){
//置信值
rawData[i][4]=data[j][i];
//类型编号
rawData[i][5]=j-4;
}
}
}
List<ArrayList<Float>> boxes=new LinkedList<ArrayList<Float>>();
ArrayList<Float> box=null;
// 置信值过滤,xywh转xyxy
for(float[] d:rawData){
// 置信值过滤
if(d[4]>confThreshold){
// 给的坐标系是中间点的
// xywh(xy为中心点)转xyxy
d[0] = d[0] - d[2]/2;
d[1] = d[1] - d[3]/2;
d[2] = d[0] + d[2];
d[3] = d[1] + d[3];
// 置信值符合的进行插入法排序保存
box=new ArrayList<Float>();
for(float num:d) {
box.add(num);
}
if(boxes.size()==0){
boxes.add(box);
}else {
int i;
for(i=0;i<boxes.size();i++){
if(box.get(4)>boxes.get(i).get(4)){
boxes.add(i,box);
break;
}
}
// 插入到最后
if(i==boxes.size()){
boxes.add(box);
}
}
}
}
// 每个框分别有x1、x1、x2、y2、conf、class
//System.out.println(boxes);
// 非极大值抑制
int[] indexs=new int[boxes.size()];
Arrays.fill(indexs,1);
//用于标记1保留,0删除
for(int cur=0;cur<boxes.size();cur++){
if(indexs[cur]==0){
continue;
}
//当前框代表该类置信值最大的框
ArrayList<Float> curMaxConf=boxes.get(cur);
for(int i=cur+1;i<boxes.size();i++){
if(indexs[i]==0){
continue;
}
float classIndex=boxes.get(i).get(5);
// 两个检测框都检测到同一类数据,通过iou来判断是否检测到同一目标,这就是非极大值抑制
if(classIndex==curMaxConf.get(5)){
float x1=curMaxConf.get(0);
float y1=curMaxConf.get(1);
float x2=curMaxConf.get(2);
float y2=curMaxConf.get(3);
float x3=boxes.get(i).get(0);
float y3=boxes.get(i).get(1);
float x4=boxes.get(i).get(2);
float y4=boxes.get(i).get(3);
//将几种不相交的情况排除。提示:x1y1、x2y2、x3y3、x4y4对应两框的左上角和右下角
if(x1>x4||x2<x3||y1>y4||y2<y3){
continue;
}
// 两个矩形的交集面积
float intersectionWidth =Math.max(x1, x3) - Math.min(x2, x4);
float intersectionHeight=Math.max(y1, y3) - Math.min(y2, y4);
float intersectionArea =Math.max(0,intersectionWidth * intersectionHeight);
// 两个矩形的并集面积
float unionArea = (x2-x1)*(y2-y1)+(x4-x3)*(y4-y3)-intersectionArea;
// 计算IoU
float iou = intersectionArea / unionArea;
// 对交并比超过阈值的标记
indexs[i]=iou>nmsThreshold?0:1;
//System.out.println(cur+" "+i+" class"+curMaxConf.get(5)+" "+classIndex+" u:"+unionArea+" i:"+intersectionArea+" iou:"+ iou);
}
}
}
List<ArrayList<Float>> resBoxes=new LinkedList<ArrayList<Float>>();
for(int index=0;index<indexs.length;index++){
if(indexs[index]==1) {
resBoxes.add(boxes.get(index));
}
}
boxes=resBoxes;
float scaleW=srcw/netWidth;
float scaleH=srch/netHeight;
System.out.println("boxes.size : "+boxes.size());
for(ArrayList<Float> box1:boxes){
box1.set(0,box1.get(0)*scaleW);
box1.set(1,box1.get(1)*scaleH);
box1.set(2,box1.get(2)*scaleW);
box1.set(3,box1.get(3)*scaleH);
}
System.out.println("boxes: "+boxes);
//detect(boxes);
Map<Object,Object> map=new HashMap<Object,Object>();
map.put("boxes",boxes);
map.put("classNames",names);
return map;
}
处理之后的map就是图片上需要标记的框及识别的标签。
五、画框,输出结果
后面就是根据map数据,在原图上画框
public Mat showDetect(Map<Object,Object> map){
List<ArrayList<Float>> boxes=(List<ArrayList<Float>>)map.get("boxes");
JSONObject names=(JSONObject) map.get("classNames");
// 画框,加数据
for(ArrayList<Float> box:boxes){
float x1=box.get(0);
float y1=box.get(1);
float x2=box.get(2);
float y2=box.get(3);
float config=box.get(4);
String className=(String)names.get((int)box.get(5).intValue());;
Point point1=new Point(x1,y1);
Point point2=new Point(x2,y2);
Imgproc.rectangle(src, point1, point2 , color,2);
String conf=new DecimalFormat("#.###").format(config);
Imgproc.putText(src,className+" "+conf,new Point(x1,y1-5),0,0.5,new Scalar(255,0,0),1);
}
return src;
}
我在这里是直接把mat转成流保存到本地了,各位就拿着mat放到response的outputstream里扔出去。
效果参考
狗头保命:图片来自于百度图片
结尾
这里只是做一个demo学习用,其中看不懂的我也不懂,各位自行百度。
参考文章:ONNX格式模型 学习笔记 (onnxRuntime部署)---用java调用yolov8模型来举例_yolov8 java-优快云博客