上文详解yolov8的nms中multi-label功能为什么不是真正的multi-label任务实现说到在yolov8的detection训练过程都是以每个像素最大面积真实目标框来作为loss的target来计算损失率和训练的,所以无法对重叠的多类别框进行训练,即使由nms产生的multi-label框,也不是真正的multi-label网络。
训练过程中,决定是否同时训练两个标签的是标签分配策略TaskAlignedAssigner过程。标签分配最关键的一步是select_highest_overlaps函数,它将获得在每个像素点上面积最大的唯一GT,target_gt_idx。往后target_labels, target_bboxes, target_scores都由self.get_targets函数产生。
那么在select_highest_overlaps函数中是什么决定或产生了最后的target_gt_idx呢?在代码中可以很明显看出:
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
# (b, n_max_boxes, h*w) -> (b, h*w)
fg_mask = mask_pos.sum(-2)
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1]) # (b, n_max_boxes, h*w)
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes) # (b, h*w, n_max_boxes)
is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype) # (b, n_max_boxes, h*w)
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos) # (b, n_max_boxes, h*w)
fg_mask = mask_pos.sum(-2)
# find each grid serve which gt(index)
target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
return target_