2021SC@SDUSC
post_process.py
该程序采用nms(非极大值抑制)算法,去掉模型预测后的多余框。
def nms(dets, thresh):
    """Apply classic DPM-style greedy NMS.

    Args:
        dets: 2-D array whose rows are ``[score, x1, y1, x2, y2]``.
        thresh: IoU threshold. A lower-scoring box whose IoU with an
            already kept box is ``>= thresh`` is suppressed.

    Returns:
        The rows of ``dets`` that survive suppression, in their original
        row order (not sorted by score).
    """
    if dets.shape[0] == 0:
        return dets[[], :]
    scores = dets[:, 0]
    x1 = dets[:, 1]
    y1 = dets[:, 2]
    x2 = dets[:, 3]
    y2 = dets[:, 4]

    # Box areas use the inclusive-pixel (+1) convention.
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Indices of the boxes sorted by descending score.
    order = scores.argsort()[::-1]

    ndets = dets.shape[0]
    # Boolean mask instead of ``dtype=np.int``: the ``np.int`` alias was
    # removed from NumPy and raises AttributeError on current releases.
    suppressed = np.zeros(ndets, dtype=bool)

    # Naming used below:
    #   _i, _j -> positions in the score-sorted ``order``
    #   i, j   -> the corresponding row indices into ``dets``
    #   i is the box currently under consideration, j a lower-scoring box
    for _i in range(ndets):
        i = order[_i]
        if suppressed[i]:
            continue
        ix1 = x1[i]
        iy1 = y1[i]
        ix2 = x2[i]
        iy2 = y2[i]
        iarea = areas[i]
        for _j in range(_i + 1, ndets):
            j = order[_j]
            if suppressed[j]:
                continue
            # Intersection rectangle of box i and box j.
            xx1 = max(ix1, x1[j])
            yy1 = max(iy1, y1[j])
            xx2 = min(ix2, x2[j])
            yy2 = min(iy2, y2[j])
            w = max(0.0, xx2 - xx1 + 1)
            h = max(0.0, yy2 - yy1 + 1)
            inter = w * h
            # Intersection over union.
            ovr = inter / (iarea + areas[j] - inter)
            if ovr >= thresh:
                suppressed[j] = True
    return dets[~suppressed, :]
该程序为具体的nms算法的实现方法。将模型预测矩阵输入。
nms算法流程如下:
- 选取这类box中scores最大的那一个,记为box_best,并保留它
- 计算box_best与其余的box的IOU
- 如果其IOU大于等于设定的阈值(例如0.5),那么就舍弃这个box(由于可能这两个box表示同一目标,所以保留分数高的那一个)
- 从最后剩余的boxes中,再找出最大scores的那一个,如此循环往复
nms算法的思想是搜索局部最大值,抑制非极大值。
def soft_nms(dets, sigma, thres):
    """Apply Gaussian Soft-NMS.

    Instead of discarding boxes that overlap the current best box, their
    scores are decayed by ``exp(-iou**2 / sigma)``; a box is dropped only
    once its decayed score falls below ``thres``.

    Args:
        dets: 2-D array whose rows are ``[score, x1, y1, x2, y2]``.
            The array is not modified.
        sigma: Divisor of the squared IoU inside the Gaussian decay.
        thres: Minimum (decayed) score a box needs to stay a candidate.

    Returns:
        Array of shape ``(M, 5)`` holding the selected boxes in the order
        they were picked, each with the score it had when picked.
    """
    # The loop below rewrites the score column in place. Work on a private
    # copy so the caller's array (which may be reused afterwards, e.g. as
    # voting weights) keeps its original scores.
    dets = dets.copy()
    dets_final = []
    while len(dets) > 0:
        maxpos = np.argmax(dets[:, 0])
        dets_final.append(dets[maxpos].copy())
        _, tx1, ty1, tx2, ty2 = dets[maxpos]
        scores = dets[:, 0]
        # Force removal of the picked box: a negative score can never
        # reach ``thres`` again after the decay below.
        scores[maxpos] = -1
        x1 = dets[:, 1]
        y1 = dets[:, 2]
        x2 = dets[:, 3]
        y2 = dets[:, 4]
        areas = (x2 - x1 + 1) * (y2 - y1 + 1)
        # IoU of every remaining box with the picked box.
        xx1 = np.maximum(tx1, x1)
        yy1 = np.maximum(ty1, y1)
        xx2 = np.minimum(tx2, x2)
        yy2 = np.minimum(ty2, y2)
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        ovr = inter / (areas + areas[maxpos] - inter)
        # Gaussian decay: the larger the overlap, the smaller the weight.
        weight = np.exp(-(ovr * ovr) / sigma)
        scores = scores * weight
        idx_keep = np.where(scores >= thres)
        dets[:, 0] = scores
        dets = dets[idx_keep]
    dets_final = np.array(dets_final).reshape(-1, 5)
    return dets_final
该方法为nms算法的改进方法。NMS是为了去除重复的预测框而设计的算法,但在物体相互重叠时会把相邻目标的框误删,故使用soft_nms,即nms的改进算法:它不直接删除重叠框,而是按重叠程度衰减其得分。
下图为soft_nms算法流程
def bbox_area(box):
    """Return the area of ``box``.

    ``box`` is indexed as ``[x1, y1, x2, y2]``; width and height use the
    inclusive-pixel (+1) convention.
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    return (x_max - x_min + 1) * (y_max - y_min + 1)
def bbox_overlaps(x, y):
    """Compute the pairwise IoU between two sets of boxes.

    Args:
        x: 2-D array of N boxes, columns ``[x1, y1, x2, y2]``.
        y: 2-D array of K boxes, columns ``[x1, y1, x2, y2]``.

    Returns:
        ``float32`` array of shape ``(N, K)``; entry ``[n, k]`` is the IoU
        of ``x[n]`` and ``y[k]`` (0 where the boxes do not intersect).
    """
    num_x, num_y = x.shape[0], y.shape[0]
    iou = np.zeros((num_x, num_y), dtype=np.float32)
    for col in range(num_y):
        ref = y[col]
        ref_area = (ref[2] - ref[0] + 1) * (ref[3] - ref[1] + 1)
        for row in range(num_x):
            cand = x[row]
            inter_w = min(cand[2], ref[2]) - max(cand[0], ref[0]) + 1
            if not inter_w > 0:
                continue
            inter_h = min(cand[3], ref[3]) - max(cand[1], ref[1]) + 1
            if not inter_h > 0:
                continue
            cand_area = (cand[2] - cand[0] + 1) * (cand[3] - cand[1] + 1)
            inter_area = inter_w * inter_h
            union_area = cand_area + ref_area - inter_area
            iou[row, col] = inter_area / union_area
    return iou
def box_voting(nms_dets, dets, vote_thresh):
    """Refine the coordinates of NMS survivors by score-weighted voting.

    Args:
        nms_dets: 2-D array of kept detections, rows ``[score, box...]``.
        dets: 2-D array of all candidate detections, same row layout.
        vote_thresh: Minimum IoU a candidate needs with a kept box to
            take part in that box's vote.

    Returns:
        A copy of ``nms_dets`` whose box columns are replaced by the
        average of the voting candidates' boxes, weighted by their scores.
        The score column is left unchanged.
    """
    refined = nms_dets.copy()
    candidate_scores = dets[:, 0]
    candidate_boxes = dets[:, 1:]
    iou = bbox_overlaps(nms_dets[:, 1:], candidate_boxes)
    for row, row_iou in enumerate(iou):
        voters = np.where(row_iou >= vote_thresh)[0]
        refined[row, 1:] = np.average(
            candidate_boxes[voters, :],
            axis=0,
            weights=candidate_scores[voters])
    return refined
该代码段定义了计算预测框面积(bbox_area)、计算两组框之间重叠度IoU(bbox_overlaps)以及框投票(box_voting)的方法。其中框投票是将nms保留下来的每个框与所有候选框计算IoU,并对IoU不低于阈值的候选框按得分加权平均,用平均后的坐标替换保留框的坐标。
def get_nms_result(boxes,
                   scores,
                   config,
                   num_classes,
                   background_label=0,
                   labels=None):
    """Run per-class NMS and cap the number of detections per image.

    Args:
        boxes: Box array. Without ``labels`` the columns
            ``[4 * j, 4 * (j + 1))`` are read as the box for class ``j``;
            with ``labels`` every row is one box.
        scores: Score array. Without ``labels`` column ``j`` is the score
            for class ``j``; with ``labels`` it is indexed per row.
        config: Mapping read for ``'nms_thresh'`` and
            ``'detections_per_im'``; for ``'score_thresh'`` when
            ``labels`` is None; optionally ``'use_soft_nms'`` (with
            ``'sigma'``) and ``'enable_voting'`` (with ``'vote_thresh'``).
        num_classes: Number of classes, iterated up to but excluding it.
        background_label: When 0, class 0 is skipped as background.
        labels: Optional per-row class ids; when given, rows are grouped
            by label instead of being filtered by score threshold.

    Returns:
        ``float32`` array whose rows are ``[label, score, box...]``,
        stacked class by class.
    """
    first_cls = 1 if background_label == 0 else 0
    class_ids = range(first_cls, num_classes)
    per_class = [[] for _ in range(num_classes)]

    for cls in class_ids:
        if labels is not None:
            picked = np.where(labels == cls)[0]
            cls_scores = scores[picked]
            cls_coords = boxes[picked, :]
        else:
            picked = np.where(scores[:, cls] > config['score_thresh'])[0]
            cls_scores = scores[picked, cls]
            cls_coords = boxes[picked, cls * 4:(cls + 1) * 4]
        candidates = np.hstack((cls_scores[:, np.newaxis], cls_coords))
        candidates = candidates.astype(np.float32, copy=False)

        if config.get('use_soft_nms', False):
            kept = soft_nms(candidates, config['sigma'], config['nms_thresh'])
        else:
            kept = nms(candidates, config['nms_thresh'])
        if config.get('enable_voting', False):
            kept = box_voting(kept, candidates, config['vote_thresh'])

        # Prepend the class id as the first column.
        label_col = np.array([cls] * len(kept))
        labelled = np.hstack((label_col[:, np.newaxis], kept))
        per_class[cls] = labelled.astype(np.float32, copy=False)

    # Limit to max_per_image detections **over all classes**: keep only
    # rows whose score reaches the max_per_image-th highest score.
    all_scores = np.hstack([per_class[cls][:, 1] for cls in class_ids])
    max_dets = config['detections_per_im']
    if len(all_scores) > max_dets:
        cutoff = np.sort(all_scores)[-max_dets]
        for cls in class_ids:
            survivors = np.where(per_class[cls][:, 1] >= cutoff)[0]
            per_class[cls] = per_class[cls][survivors, :]

    return np.vstack([per_class[cls] for cls in class_ids])
该方法是利用nms算法得到处理后的预测结果。
def mstest_box_post_process(result, config, num_classes):
    """
    Multi-scale Test
    Only available for batch_size=1 now.

    Gathers every ``'bbox*'`` entry of ``result`` together with its
    matching ``'score*'`` entry (same key suffix), flips back the boxes
    of keys containing ``'flip'``, concatenates everything and runs
    ``get_nms_result`` once over the merged predictions.

    Returns:
        Dict with ``'bbox'`` mapped to ``(detections, [[count]])`` and,
        if any flipped input was seen, ``'bbox_flip'`` holding the same
        detections with their box columns flipped.
    """
    im_shape = result['im_shape'][0]
    saw_flip = False
    box_parts = []
    score_parts = []
    for key in result:
        if 'bbox' not in key:
            continue
        cur_boxes = np.reshape(result[key][0], (-1, 4 * num_classes))
        # 'bbox<suffix>' pairs with 'score<suffix>'.
        cur_scores = result['score' + key[4:]][0]
        if 'flip' in key:
            cur_boxes = box_flip(cur_boxes, im_shape)
            saw_flip = True
        box_parts.append(cur_boxes)
        score_parts.append(cur_scores)

    merged_boxes = np.concatenate(box_parts)
    merged_scores = np.concatenate(score_parts)
    bbox_pred = get_nms_result(merged_boxes, merged_scores, config,
                               num_classes)

    post_bbox = {'bbox': (bbox_pred, [[len(bbox_pred)]])}
    if saw_flip:
        # Columns 0-1 are label and score; the rest are box coordinates.
        head_cols = bbox_pred[:, :2]
        flipped_coords = box_flip(bbox_pred[:, 2:], im_shape)
        bbox_flip = np.append(head_cols, flipped_coords, axis=1)
        post_bbox['bbox_flip'] = (bbox_flip, [[len(bbox_flip)]])
    return post_bbox
该方法是多尺度测试(Multi-scale Test)时检测框的后处理函数。
def corner_post_process(results, config, num_classes):
    """Post-process detections in place by running label-grouped NMS.

    Reads ``results['bbox'][0]`` as rows of ``[label, score, box...]``,
    drops rows whose score is not greater than -1, runs
    ``get_nms_result`` grouped by label, and writes the outcome back to
    ``results['bbox']`` as ``(detections, [[count]])``.

    Returns:
        None; ``results`` is updated in place.
    """
    raw = results['bbox'][0]
    valid = raw[raw[:, 1] > -1]
    nms_out = get_nms_result(
        valid[:, 2:6],
        valid[:, 1],
        config,
        num_classes,
        background_label=-1,
        labels=valid[:, 0])
    results.update({'bbox': (nms_out, [[len(nms_out)]])})
该方法是用于CornerNet的后处理函数。
参数:
- results:dict{'bbox', 'im_id'},检测结果;其中字典中的属性含义如下:
  - bbox:检测框信息,包含类别和坐标信息
  - im_id:图像id