详解yolov8的nms中multi-label功能为什么不是真正的multi-label任务实现

文章探讨了YOLOv8中的非极大抑制(NMS)在多标签分类任务中的应用,指出NMS本质上并非真正的多标签网络。同时,详细解释了v8的lossfunction计算过程,尤其是TaskAlignedAssigner在样本分配和loss计算中的关键作用,强调了模型对每个像素最大面积真实目标框的依赖,限制了对重叠多类别框的训练能力。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、什么是multi-label?

多标签分类(Multilabel classification): 给每个样本一系列的目标标签,即表示的是样本各属性而不是相互排斥的。比如图片中的猫可同时拥有两个标签cat、animal,需要预测出一个概念集合。

2.一般思路如何实现multi-label任务?

要实现这个任务,一种是使用多个模型,可以并行使用两个模型分别预测同一个物体,每个模型对该物体的预测不同。即一个模型预测图片中的猫为cat,另一个预测其为animal。这种方法比较简单实用,但可能满足不了一些场合的单一模型推理要求。

一种是专门设计一个网络同时对物体带有的多个标签进行训练,设计思路:
1.从网络的数据集、输入、损失函数、标签分配策略进行修改。
2.类似multi-task网络的形式,对网络输出做分支并行。

(两种实现方法并不是本文所讲述主题,一言带过~)

二、yolov8中nms函数的multi-label

首先放一段v8中nms源码

def non_max_suppression(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
        classes=None,
        agnostic=False,
        multi_label=False,
        labels=(),
        max_det=300,
        nc=0,  # number of classes (optional)
        max_time_img=0.05,
        max_nms=30000,
        max_wh=7680,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.

    Args:
        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing the predicted boxes, classes, and masks. The tensor should be in the format
            output by a model, such as YOLO.
        conf_thres (float): The confidence threshold below which boxes will be filtered out.
            Valid values are between 0.0 and 1.0.
        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
            Valid values are between 0.0 and 1.0.
        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
        agnostic (bool): If True, the model is agnostic to the number of classes, and all
            classes will be considered as one.
        multi_label (bool): If True, each box may have multiple labels.
        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
            list contains the apriori labels for a given image. The list should be in the format
            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
        max_det (int): The maximum number of boxes to keep after NMS.
        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
        max_time_img (float): The maximum time (seconds) for processing one image.
        max_nms (int): The maximum number of boxes into torchvision.ops.nms().
        max_wh (int): The maximum box width and height in pixels

    Returns:
        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
    """

    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation model, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4
    mi = 4 + nc  # mask
### YOLOv8 中集成 DIoU-NMS 的方法 要在 YOLOv8 后处理中实现 DIoU-NMS 方法,可以通过修改其源码来完成。以下是具体的操作方式: #### 修改 `non_max_suppression` 函数 YOLOv8 使用的非极大值抑制 (NMS) 实现通常位于工具模块中(如 `ultralytics/yolo/utils/ops.py` 或类似的文件路径),其中定义了一个名为 `non_max_suppression` 的函数。为了支持 DIoU-NMS,需要调整该函数。 以下是一个基于 YOLOv5 和 YOLOv8 的代码示例,展示如何将 DIoU-NMS 集成到后处理阶段[^1]: ```python import torch def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=True, diou_nms=True): """ 执行非极大值抑制 (NMS),可以选择使用 DIoU-NMS。 参数: prediction: 模型预测结果 (batch_size, num_boxes, [xywh, obj_conf, class_scores]) conf_thres: 置信度阈值 iou_thres: IoU 阈值 classes: 是否过滤特定类别 agnostic: 类别无关 NMS multi_label: 多标签分类模式 diou_nms: 是否启用 DIoU-NMS 返回: list of detections, on (n,6) tensor per image [xyxy, conf, cls] """ # 初始化变量 nc = prediction.shape[2] - 5 # 数字类别的数量 xc = prediction[..., 4] > conf_thres # 候选框筛选条件 # 输出列表初始化 output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] for xi, x in enumerate(prediction): # 对每张图片的结果进行迭代 x = x[xc[xi]] # 过滤掉低置信度候选框 if not x.shape[0]: continue # 计算 xyxy 格式的边界框坐标 box = xywh2xyxy(x[:, :4]) # 如果启用了多标签,则按类别分别处理 if multi_label: i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) else: # 单一最高分数类别 conf, j = x[:, 5:].max(1, keepdim=True) x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] # 排序并执行 NMS n = x.shape[0] # 当前图像中的检测数 if not n: continue c = x[:, 5:6] * (0 if agnostic else max_wh) # 类别偏移 boxes, scores = x[:, :4] + c, x[:, 4] # 调整后的盒子和得分 if diou_nms: from utils.metrics import bbox_iou # 导入 DIoU 计算函数 ious = bbox_iou(boxes.unsqueeze(1), boxes.unsqueeze(0), x1y1x2y2=False, CIoU=True) # 计算 DIoU 矩阵 selected_indices = [] while True: _, idx = scores.max(0) selected_indices.append(idx.item()) if len(selected_indices) >= n or scores[idx] < conf_thres: break mask = ious[idx] <= iou_thres scores = scores[mask.squeeze()] boxes = boxes[mask.squeeze()] x = x[selected_indices] else: from torchvision.ops import nms indices = nms(boxes, scores, iou_thres) x = x[indices] output[xi] = x return output ``` #### 关键点说明 1. **DIoU 计算**: 上述代码通过调用 `bbox_iou` 函数实现了 DIoU 的计算逻辑[^1]。如果未找到对应的函数,可以从 YOLOv5 的 `utils/metrics.py` 文件复制其实现。 2. **自定义 NMS 循环**: 在启用 DIoU-NMS 的情况下,手动构建循环以逐步选择最优目标,并剔除与其他目标重叠率超过设定阈值的对象。 3. **兼容性**: 新增参数 `diou_nms` 控制是否启用 DIoU-NMS,默认设置为 False,以便保持原有功能不变。 --- ### 注意事项 - 若发现性能瓶颈,建议测试不同硬件平台下的运行效率,因为 DIoU-NMS 可能会稍微增加计算开销。 - 确保导入的辅助函数(如 `bbox_iou`)已正确配置于项目目录下。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值