例子是bbox_overlaps计算anchors和gts的iou。
输入: anchors: (N, 4) ndarray of float
gt_boxes: (K, 4) ndarray of float
输出: overlaps: (N, K) ndarray of overlap between boxes and query_boxes
首先想到for i in range(N)
for j in range(K)
overlaps[i][i] = ****
矩阵思维:创建新的维度
#anchors(N, 4) _>boxes(N, K, 4)
#gt_boxes: (K, 4) _> query_boxes(N, K, 4)

def bbox_overlaps(anchors, gt_boxes):
"""
anchors: (N, 4) ndarray of float
gt_boxes: (K, 4) ndarray of float
overlaps: (N, K) ndarray of overlap between boxes and query_boxes
"""
N = anchors.size(0)
K = gt_boxes.size(0)
#gt_boxes_area (1, K)
gt_boxes_area = ((gt_boxes[:,2] - gt_boxes[:,0] + 1) *
(gt_boxes[:,3] - gt_boxes[:,1] + 1)).view(1, K)
#anchors_area (N, 1)
anchors_area = ((anchors[:,2] - anchors[:,0] + 1) *
(anchors[:,3] - anchors[:,1] + 1)).view(N, 1)
#anchors(N, 4) _>boxes(N, K, 4)
boxes = anchors.view(N, 1, 4).expand(N, K, 4)
#gt_boxes: (K, 4) _> query_boxes(N, K, 4)
query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4)
iw = (torch.min(boxes[:,:,2], query_boxes[:,:,2]) -
torch.max(boxes[:,:,0], query_boxes[:,:,0]) + 1)
iw[iw < 0] = 0
ih = (torch.min(boxes[:,:,3], query_boxes[:,:,3]) -
torch.max(boxes[:,:,1], query_boxes[:,:,1]) + 1)
ih[ih < 0] = 0
#broad (1, K)+(N, 1)= (N, K) - (N, K)
ua = anchors_area + gt_boxes_area - (iw * ih)
overlaps = iw * ih / ua
return overlaps

本文深入解析bbox_overlaps函数,该函数用于计算anchors和ground truth boxes之间的IoU(交并比)。详细介绍了输入参数anchors和gt_boxes的处理过程,以及如何通过矩阵运算高效地计算出(N,K)维度的overlap矩阵。
7万+

被折叠的 条评论
为什么被折叠?



