【深度学习】python实现NMS (非极大值抑制)
python实现NMS
首先需要明确一点:在多类别的目标检测任务中,NMS是在每个类别内部分别进行的,不同类别的框之间不会相互抑制。
NMS的流程:
# Taking the YOLO family of detectors as an example, every row of the network
# output has the following layout:
# [center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score...]
# Assuming there are three classes in total, the input array looks like this:
# boxes = np.array([[center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score],
#                   [center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score],
#                   ...
#                   [center_x, center_y, width, height, confidence, class1_score, class2_score, class3_score]])
# Collects the boxes that survive NMS (one list per class).
result = []
# Process the outputs of each class separately
# (类别数量 stands for the number of classes, 3 in this example).
for each in range(类别数量):
    # Select every box whose highest class score belongs to the current class.
    the_boxes = boxes[np.where(boxes[:, 5:8].argsort()[:,-1] == each)[0].tolist(), :]
    center_x = the_boxes[:, 0]
    center_y = the_boxes[:, 1]
    width = the_boxes[:, 2]
    height = the_boxes[:, 3]
    confidence = the_boxes[:, 4]
    # Box indices sorted by confidence, highest first.
    index = confidence.argsort()[::-1]
    # Boxes of the current class that survive NMS.
    keep = []
    # Compute the IOU between the most confident box and all remaining boxes;
    # boxes whose IOU exceeds the threshold are dropped from index, and the
    # most confident box is kept.
    # The kept box is then removed from index as well and the process repeats
    # until index is empty.
    # All boxes kept along the way form the NMS result.
    while index.size > 0:
        best = index[0]
        keep.append(the_boxes[best, :])
        # get_iou computes the IOU between the most confident box (best) and
        # every other box still listed in index (i.e. index[1:]).
        ious = get_iou(index, best, center_x, center_y, width, height)
        # thresh is the IOU threshold used by NMS.
        idx = np.where(ious <= thresh)[0]
        # ious was computed against index[1:], so positions in idx are shifted
        # by one relative to index; add 1 when updating index.
        index = index[idx + 1]
    result.append(keep)
python实现:
# -*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
# Assume this array is the output of a YOLO network. Each row is laid out as
# [center_x, center_y, width, height, confidence,
#  class1_score, class2_score, class3_score].
boxes = np.array(
    [
        [155, 155, 110, 110, 0.72, 0.2, 0.9, 0.7],
        [335, 335, 170, 170, 0.8, 0.3, 0.4, 0.8],
        [270, 275, 100, 110, 0.92, 0.6, 0.8, 0.2],
        [165, 255, 90, 110, 0.72, 0.1, 0.5, 0.4],
        [277, 285, 95, 90, 0.81, 0.3, 0.9, 0.4],
        [225, 225, 150, 150, 0.7, 0.3, 0.4, 0.7],
        [350, 250, 100, 100, 0.8, 0.9, 0.2, 0.6],
        [267, 285, 95, 110, 0.9, 0.2, 0.7, 0.6],
    ]
)
def get_iou(index, best, center_x, center_y, width, height):
    """Return the IOU between box ``best`` and the boxes listed in ``index[1:]``.

    :param index: array of box indices; its first entry is skipped and the
        remaining entries select the boxes compared against ``best``
    :param best: index of the reference box
    :param center_x: array of box centre x coordinates
    :param center_y: array of box centre y coordinates
    :param width: array of box widths
    :param height: array of box heights
    :return: array of IOU values, one per entry of ``index[1:]``
    """
    rest = index[1:]

    # Convert centre/size to corner coordinates.
    left = center_x - width / 2
    top = center_y - height / 2
    right = center_x + width / 2
    bottom = center_y + height / 2

    # Side lengths carry a "+ 1" term, both here and in the intersection below.
    areas = (bottom - top + 1) * (right - left + 1)

    # Corners of the intersection rectangle between `best` and each other box.
    inter_left = np.maximum(left[best], left[rest])
    inter_top = np.maximum(top[best], top[rest])
    inter_right = np.minimum(right[best], right[rest])
    inter_bottom = np.minimum(bottom[best], bottom[rest])

    # Clamp to zero so that non-intersecting boxes contribute no overlap.
    inter_w = np.maximum(0, inter_right - inter_left + 1)
    inter_h = np.maximum(0, inter_bottom - inter_top + 1)
    intersection = inter_w * inter_h

    union = areas[best] + areas[rest] - intersection
    return intersection / union
def nms(dets, thresh):
    """Run per-class non-maximum suppression on detector output.

    Every row of ``dets`` is laid out as
    ``[center_x, center_y, width, height, confidence, class1_score, ...]``.
    A box is assigned to the class with its highest class score, and
    suppression is carried out independently inside each class.

    :param dets: 2-D numpy array with at least 6 columns (5 box columns plus
        one score column per class)
    :param thresh: IOU threshold; a box whose IOU with an already kept, more
        confident box of the same class exceeds this value is discarded
    :return: 2-D array of the kept rows, grouped by class id and ordered by
        descending confidence inside each class; shape ``(0, n_columns)``
        when nothing is kept
    :raises ValueError: if ``dets`` is not 2-D or has no class score columns
    """
    dets = np.asarray(dets)
    if dets.ndim != 2 or dets.shape[1] < 6:
        raise ValueError(
            "dets must be a 2-D array with at least 6 columns, got shape %s"
            % (dets.shape,)
        )

    # One score column per class follows the five box columns.
    num_classes = dets.shape[1] - 5
    # Class id of each box = column holding its highest class score.
    class_ids = dets[:, 5:].argsort()[:, -1]

    result = []
    for each in range(num_classes):
        the_boxes = dets[class_ids == each]
        # A class without any boxes has nothing to suppress; skipping it also
        # avoids np.concatenate failing on an empty list.
        if the_boxes.shape[0] == 0:
            continue

        center_x = the_boxes[:, 0]
        center_y = the_boxes[:, 1]
        width = the_boxes[:, 2]
        height = the_boxes[:, 3]
        confidence = the_boxes[:, 4]

        # Candidate indices, most confident first.
        index = confidence.argsort()[::-1]
        keep = []
        while index.size > 0:
            best = index[0]
            keep.append(best)
            # IOU of the best box against every remaining candidate (index[1:]).
            ious = get_iou(index, best, center_x, center_y, width, height)
            # Drop the best box itself and every candidate overlapping it
            # by more than the threshold.
            index = index[1:][ious <= thresh]
        result.append(the_boxes[keep])

    if not result:
        return np.empty((0, dets.shape[1]), dtype=dets.dtype)
    return np.concatenate(result, axis=0)
def plot_bbox(dets):
    """Draw every box in ``dets`` on the current matplotlib axes.

    Each row of ``dets`` is read as
    ``[center_x, center_y, width, height, confidence, class scores...]``.
    The class of a box is the column (among columns 5 to 7) with the highest
    score, and each of the three classes is drawn in its own colour.

    :param dets: 2-D numpy array of detections
    """
    palette = ["lime", "magenta", "cyan"]
    labels = dets[:, 5:8].argsort()[:, -1].tolist()
    for row, label in zip(dets, labels):
        cx, cy, w, h = row[0], row[1], row[2], row[3]
        # Integer corner coordinates of the box.
        left = int(cx - w / 2)
        top = int(cy - h / 2)
        right = int(cx + w / 2)
        bottom = int(cy + h / 2)
        colour = palette[label]
        # The four edges of the rectangle, drawn one line segment at a time.
        edges = (
            ([left, right], [top, top]),
            ([left, left], [top, bottom]),
            ([left, right], [bottom, bottom]),
            ([right, right], [top, bottom]),
        )
        for xs, ys in edges:
            plt.plot(xs, ys, colour)
# Show the boxes before and after NMS in two subplots of figure 1.
plt.figure(1)
ax1 = plt.subplot(1, 2, 1)
ax2 = plt.subplot(1, 2, 2)
# Make the first subplot current so plot_bbox draws into it.
plt.sca(ax1)
# Boxes before NMS
plot_bbox(boxes)
# Boxes after NMS (IOU threshold 0.7)
keep = nms(boxes, thresh=0.7)
# Switch to the second subplot for the suppressed result.
plt.sca(ax2)
plot_bbox(keep)
plt.show()
如上图,左边是未经过NMS处理的boxes,右边是经过NMS处理的boxes。(不同的颜色代表不同的类别)
结语
如果您有修改意见或问题,欢迎留言或者通过邮箱和我联系。
如果我的文章对您有帮助,转载请注明出处。