NMS(非极大值抑制)代码解析

本文详细介绍了YOLO目标检测算法的输出解析过程,包括如何从网络输出中获取目标位置及类别概率,应用非极大值抑制(NMS)进行目标筛选。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

更多代码实现:https://github.com/Tangzixia/Vehicle-Detection-YOLO-ver/blob/master/yolo_tiny.py

def interpret_output(yolo, output):
    probs = np.zeros((7, 7, 2, 20))
    class_probs = np.reshape(output[0:980], (7, 7, 20))
    scales = np.reshape(output[980:1078], (7, 7, 2))
    boxes = np.reshape(output[1078:], (7, 7, 2, 4))
    offset = np.transpose(np.reshape(np.array([np.arange(7)] * 14), (2, 7, 7)), (1, 2, 0))

    boxes[:, :, :, 0] += offset
    boxes[:, :, :, 1] += np.transpose(offset, (1, 0, 2))
    boxes[:, :, :, 0:2] = boxes[:, :, :, 0:2] / 7.0
    boxes[:, :, :, 2] = np.multiply(boxes[:, :, :, 2], boxes[:, :, :, 2])
    boxes[:, :, :, 3] = np.multiply(boxes[:, :, :, 3], boxes[:, :, :, 3])

    boxes[:, :, :, 0] *= yolo.w_img
    boxes[:, :, :, 1] *= yolo.h_img
    boxes[:, :, :, 2] *= yolo.w_img
    boxes[:, :, :, 3] *= yolo.h_img

    for i in range(2):
        for j in range(20):
            probs[:, :, i, j] = np.multiply(class_probs[:, :, j], scales[:, :, i])
    #这里我们得到了7*7*2个框,即有98个框,每个框有自己对应的20个类的得分,即有1960个得分,通过设定阈值删除一些特别低的得分
    filter_mat_probs = np.array(probs >= yolo.threshold, dtype='bool')
    #这儿假设我们得到了1160个得分,因为它们是在[7,7,2,20]的维度上面的表示,所以现在我们用filter_mat_boxes可以1160个“True”得分在7,7,2,20四个维度的得分!
    filter_mat_boxes = np.nonzero(filter_mat_probs)
    boxes_filtered = boxes[filter_mat_boxes[0], filter_mat_boxes[1], filter_mat_boxes[2]]
    #这样将得到1160个得分,即probs_filtered的维度为(1160,)
    probs_filtered = probs[filter_mat_probs]
    #这样将得到每个得分属于哪个类,维度为(1160,)
    classes_num_filtered = np.argmax(filter_mat_probs, axis=3)[
        filter_mat_boxes[0], filter_mat_boxes[1], filter_mat_boxes[2]]

    #对所有的得分进行排序(默认是从小到大的顺序,现在我们转化为了从大到小的顺序)
    argsort = np.array(np.argsort(probs_filtered))[::-1]
    #然后就可以得到从得分从大到小排列的1160个框(可以看做98个框,每个框可以对应1~20数量的一个小框)
    boxes_filtered = boxes_filtered[argsort]
    #可以得到得分由大到小的排列顺序
    probs_filtered = probs_filtered[argsort]
    #调整好对应的类的对应关系
    classes_num_filtered = classes_num_filtered[argsort]

    #对于一个框,自己和自己肯定IOU==1,因此可以删除,然后把和它iou相交大于阈值的框都删除
    for i in range(len(boxes_filtered)):
        if probs_filtered[i] == 0: continue
        for j in range(i + 1, len(boxes_filtered)):
            if iou(boxes_filtered[i], boxes_filtered[j]) > yolo.iou_threshold:
                probs_filtered[j] = 0.0

    #到最后,剩下的框对应的就是NMS处理后的框
    filter_iou = np.array(probs_filtered > 0.0, dtype='bool')
    boxes_filtered = boxes_filtered[filter_iou]
    probs_filtered = probs_filtered[filter_iou]
    classes_num_filtered = classes_num_filtered[filter_iou]

    result = []
    for i in range(len(boxes_filtered)):
        result.append(
            [yolo.classes[classes_num_filtered[i]], boxes_filtered[i][0], boxes_filtered[i][1], boxes_filtered[i][2],
             boxes_filtered[i][3], probs_filtered[i]])

    return result
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值