使用mmrotate时遇到的BUg:nms_rotated(): incompatible function arguments. The following argument types are s

测试demo的时候,出现了这个错误,

keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
TypeError: nms_rotated(): incompatible function arguments. The following argument types are supported:

解决方法:找到“mmcv\ops\nms.py”文件中的”nms_rotated“函数

替换为以下的代码即可(mmcv.ops.nms — mmcv 1.7.2 文档

def nms_rotated(dets: Tensor,
                scores: Tensor,
                iou_threshold: float,
                labels: Optional[Tensor] = None,
                clockwise: bool = True) -> Tuple[Tensor, Tensor]:
    """Performs non-maximum suppression (NMS) on the rotated boxes according to
    their intersection-over-union (IoU).

    Rotated NMS iteratively removes lower scoring rotated boxes which have an
    IoU greater than iou_threshold with another (higher scoring) rotated box.

    Args:
        dets (torch.Tensor):  Rotated boxes in shape (N, 5).
            They are expected to be in
            (x_ctr, y_ctr, width, height, angle_radian) format.
        scores (torch.Tensor): scores in shape (N, ).
        iou_threshold (float): IoU thresh for NMS.
        labels (torch.Tensor, optional): boxes' label in shape (N,).
        clockwise (bool): flag indicating whether the positive angular
            orientation is clockwise. default True.
            `New in version 1.4.3.`

    Returns:
        tuple: kept dets(boxes and scores) and indice, which is always the
        same data type as the input.
    """
    if dets.shape[0] == 0:
        return dets, None
    if not clockwise:
        flip_mat = dets.new_ones(dets.shape[-1])
        flip_mat[-1] = -1
        dets_cw = dets * flip_mat
    else:
        dets_cw = dets
    multi_label = labels is not None
    if labels is None:
        input_labels = scores.new_empty(0, dtype=torch.int)
    else:
        input_labels = labels
    if dets.device.type in ('npu', 'mlu'):
        order = scores.new_empty(0, dtype=torch.long)
        keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,
                                           input_labels, iou_threshold,
                                           multi_label)
        dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
                         dim=1)
        return dets, keep_inds

    if multi_label:
        dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1)  # type: ignore
    else:
        dets_wl = dets_cw
    _, order = scores.sort(0, descending=True)
    dets_sorted = dets_wl.index_select(0, order)

    if torch.__version__ == 'parrots':
        keep_inds = ext_module.nms_rotated(
            dets_wl,
            scores,
            order,
            dets_sorted,
            input_labels,
            iou_threshold=iou_threshold,
            multi_label=multi_label)
    else:
        keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
                                           input_labels, iou_threshold,
                                           multi_label)
  
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值