使用mmrotate时遇到的Bug:nms_rotated(): incompatible function arguments. The following argument types are supported

测试demo的时候,出现了以下错误:

keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
TypeError: nms_rotated(): incompatible function arguments. The following argument types are supported:

解决方法:找到“mmcv\ops\nms.py”文件中的“nms_rotated”函数

替换为以下代码即可(代码来源:mmcv.ops.nms — mmcv 1.7.2 文档):

def nms_rotated(dets: Tensor,
                scores: Tensor,
                iou_threshold: float,
                labels: Optional[Tensor] = None,
                clockwise: bool = True) -> Tuple[Tensor, Tensor]:
    """Performs non-maximum suppression (NMS) on the rotated boxes according to
    their intersection-over-union (IoU).

    Rotated NMS iteratively removes lower scoring rotated boxes which have an
    IoU greater than iou_threshold with another (higher scoring) rotated box.

    The compiled ``ext_module.nms_rotated`` this wrapper targets has the
    signature ``(dets, scores, order, dets_sorted, labels, iou_threshold,
    multi_label)``; ``input_labels`` is therefore always passed, even when no
    labels are given (as an empty tensor).

    Args:
        dets (torch.Tensor): Rotated boxes in shape (N, 5).
            They are expected to be in
            (x_ctr, y_ctr, width, height, angle_radian) format.
        scores (torch.Tensor): scores in shape (N, ).
        iou_threshold (float): IoU thresh for NMS.
        labels (torch.Tensor, optional): boxes' label in shape (N,).
            When given, ``multi_label`` is enabled and the labels are handed
            to the extension (presumably so only same-label boxes suppress
            each other -- behavior lives in the compiled kernel).
        clockwise (bool): flag indicating whether the positive angular
            orientation is clockwise. default True.
            `New in version 1.4.3.`

    Returns:
        tuple: ``(dets, keep_inds)``. ``dets`` holds the kept input boxes (in
        the caller's original angle convention) with the matching score
        appended as the last column, i.e. shape (K, 6) for (N, 5) input.
        ``keep_inds`` is the index tensor returned by the extension. For empty
        input, ``dets`` has shape (0, 6) and ``keep_inds`` is an empty
        ``torch.long`` tensor on the same device as ``scores``.
    """
    if dets.shape[0] == 0:
        # Mirror the non-empty return: boxes plus a score column, and an
        # (empty) index tensor rather than None, so callers can do
        # ``boxes[keep_inds]`` / ``dets[:, -1]`` without special-casing.
        empty_inds = scores.new_empty(0, dtype=torch.long)
        return torch.cat((dets, scores.reshape(-1, 1)), dim=1), empty_inds
    if not clockwise:
        # The extension is fed clockwise-angle boxes (``dets_cw``); negate the
        # last (angle) column to convert counter-clockwise input.
        flip_mat = dets.new_ones(dets.shape[-1])
        flip_mat[-1] = -1
        dets_cw = dets * flip_mat
    else:
        dets_cw = dets
    multi_label = labels is not None
    if labels is None:
        # The compiled op always takes a labels argument; pass an empty
        # placeholder when labels are not used.
        input_labels = scores.new_empty(0, dtype=torch.int)
    else:
        input_labels = labels
    if dets.device.type in ('npu', 'mlu'):
        # On this path the kernel receives the unsorted boxes (in both box
        # slots), the labels separately, and an empty ``order`` placeholder.
        order = scores.new_empty(0, dtype=torch.long)
        keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,
                                           input_labels, iou_threshold,
                                           multi_label)
        dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
                         dim=1)
        return dets, keep_inds

    if multi_label:
        # Append labels as an extra box column for the remaining devices.
        dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1)  # type: ignore
    else:
        dets_wl = dets_cw
    # Sort by descending score; the kernel gets both the permutation and the
    # boxes already gathered in that order.
    _, order = scores.sort(0, descending=True)
    dets_sorted = dets_wl.index_select(0, order)

    if torch.__version__ == 'parrots':
        keep_inds = ext_module.nms_rotated(
            dets_wl,
            scores,
            order,
            dets_sorted,
            input_labels,
            iou_threshold=iou_threshold,
            multi_label=multi_label)
    else:
        # The positional arguments must match the compiled signature exactly;
        # leaving out ``input_labels`` makes pybind raise
        # "TypeError: nms_rotated(): incompatible function arguments".
        keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
                                           input_labels, iou_threshold,
                                           multi_label)
    dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
                     dim=1)
    return dets, keep_inds

主要原因:当前安装的mmcv中(文中称为“新版本”),nms.py调用ext_module.nms_rotated时少传了一个“input_labels”参数,与已编译扩展要求的参数签名(见下方报错中的 labels: torch.Tensor)不一致。

# Call in the installed nms.py (labelled "new version" in the original post):
# it omits input_labels, so only 6 arguments reach the compiled extension,
# which expects 7 -> "incompatible function arguments" TypeError.
keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
                                    iou_threshold, multi_label)

# Call in the replacement code (from the mmcv 1.7.2 docs, labelled "old
# version" in the original post): passes input_labels, matching the 7-argument
# signature (dets, scores, order, dets_sorted, labels, iou_threshold,
# multi_label) reported by the TypeError.
keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
                                   input_labels, iou_threshold, multi_label)

完整的错误信息如下:

 ETA:D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\__init__.py:20: UserWarning: On January 1, 2023, MMCV will release v2.0.0, in which it will remove components related to the training process and add a data transformation module. In addition, it will rename the package names mmcv to mmcv-lite and mmcv-full to mmcv. See https://github.com/open-mmlab/mmcv/blob/master/docs/en/compatibility.md for more details.
  warnings.warn(
D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\__init__.py:20: UserWarning: On January 1, 2023, MMCV will release v2.0.0, in which it will remove components related to the training process and add a data transformation module. In addition, it will rename the package names mmcv to mmcv-lite and mmcv-full to mmcv. See https://github.com/open-mmlab/mmcv/blob/master/docs/en/compatibility.md for more details.
  warnings.warn(
Traceback (most recent call last):
  File "D:/PythonProject/mmrotate-main/tools/train.py", line 194, in <module>
    main()
  File "D:/PythonProject/mmrotate-main/tools/train.py", line 183, in main
    train_detector(
  File "D:\PythonProject\mmrotate-main\mmrotate\apis\train.py", line 144, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 136, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 58, in train
    self.call_hook('after_train_epoch')
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\base_runner.py", line 317, in call_hook
    getattr(hook, fn_name)(self)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\hooks\evaluation.py", line 271, in after_train_epoch
    self._do_evaluate(runner)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmdet\core\evaluation\eval_hooks.py", line 60, in _do_evaluate
    results = single_gpu_test(runner.model, self.dataloader, show=False)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmdet\apis\test.py", line 29, in single_gpu_test
    result = model(return_loss=False, rescale=True, **data)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\parallel\data_parallel.py", line 51, in forward
    return super().forward(*inputs, **kwargs)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\torch\nn\parallel\data_parallel.py", line 183, in forward
    return self.module(*inputs[0], **module_kwargs[0])
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_func
    return old_func(*args, **kwargs)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 174, in forward
    return self.forward_test(img, img_metas, **kwargs)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 147, in forward_test
    return self.simple_test(imgs[0], img_metas[0], **kwargs)
  File "D:\PythonProject\mmrotate-main\mmrotate\models\detectors\two_stage.py", line 183, in simple_test
    return self.roi_head.simple_test(
  File "D:\PythonProject\mmrotate-main\mmrotate\models\roi_heads\roi_trans_roi_head.py", line 333, in simple_test
    det_bbox, det_label = self.bbox_head[-1].get_bboxes(
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 208, in new_func
    return old_func(*args, **kwargs)
  File "D:\PythonProject\mmrotate-main\mmrotate\models\roi_heads\bbox_heads\rotated_bbox_head.py", line 418, in get_bboxes
    det_bboxes, det_labels = multiclass_nms_rotated(
  File "D:\PythonProject\mmrotate-main\mmrotate\core\post_processing\bbox_nms_rotated.py", line 80, in multiclass_nms_rotated
    _, keep = nms_rotated(bboxes_for_nms, scores, nms.iou_thr)
  File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\ops\nms.py", line 473, in nms_rotated
    keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
TypeError: nms_rotated(): incompatible function arguments. The following argument types are supported:
    1. (dets: torch.Tensor, scores: torch.Tensor, order: torch.Tensor, dets_sorted: torch.Tensor, labels: torch.Tensor, iou_threshold: float, multi_label: int) -> torch.Tensor

Invoked with: tensor([[ 1.3929e+03,  5.8900e+02,  4.4586e+02,  1.4374e+02, -5.9560e-02],
        [ 1.3778e+03,  5.9055e+02,  4.9475e+02,  1.3340e+02,  4.2467e-03],
        [ 1.3880e+03,  5.8972e+02,  4.8726e+02,  1.3731e+02, -9.5322e-02],
        [ 1.3828e+03,  5.9207e+02,  4.6885e+02,  1.3934e+02,  6.5348e-02],
        [ 1.3763e+03,  5.9193e+02,  4.8936e+02,  1.3456e+02,  6.2798e-02],
        [ 1.3828e+03,  5.8986e+02,  4.9943e+02,  1.3395e+02, -5.6785e-02],
        [ 1.3844e+03,  5.8931e+02,  5.1308e+02,  1.3199e+02, -9.3728e-02],
        [ 1.3776e+03,  5.8879e+02,  5.1078e+02,  1.3006e+02, -1.0194e-01],
        [ 1.3818e+03,  5.9021e+02,  5.0582e+02,  1.3181e+02, -3.0323e-02],
        [ 1.3793e+03,  5.9176e+02,  4.1503e+02,  1.4637e+02, -9.3242e-02],
        [ 1.3907e+03,  5.9019e+02,  4.8419e+02,  1.2927e+02, -4.7811e-02],
        [ 1.3931e+03,  5.8952e+02,  5.0725e+02,  1.3281e+02, -1.4515e-01],
        [ 1.3969e+03,  5.8813e+02,  4.8288e+02,  1.3330e+02,  7.8660e-02],
        [ 1.3843e+03,  5.9290e+02,  4.3540e+02,  1.4532e+02, -5.1058e-02],
        [ 1.4419e+03,  5.8991e+02,  3.7714e+02,  1.3363e+02, -1.8755e-01],
        [ 1.4195e+03,  5.9087e+02,  4.5214e+02,  1.3127e+02, -9.4837e-02],
        [ 1.3749e+03,  5.8739e+02,  4.5484e+02,  1.4451e+02, -5.0357e-02],
        [ 1.4041e+03,  5.8720e+02,  4.8444e+02,  1.2699e+02, -1.0760e-01],
        [ 1.4243e+03,  5.8959e+02,  4.2513e+02,  1.3117e+02, -1.7885e-01],
        [ 1.3769e+03,  5.9002e+02,  5.0865e+02,  1.3009e+02, -1.0078e-01],
        [ 1.3739e+03,  5.9448e+02,  4.7890e+02,  1.4622e+02,  4.9585e-02],
        [ 1.3815e+03,  5.8972e+02,  5.1432e+02,  1.2936e+02, -5.7829e-02],
        [ 1.3822e+03,  5.8850e+02,  4.9508e+02,  1.3048e+02, -3.3254e-02],
        [ 1.3725e+03,  5.9071e+02,  4.4382e+02,  1.3830e+02, -1.3818e-01],
        [ 1.3911e+03,  5.8871e+02,  5.6314e+02,  1.3148e+02,  6.5076e-02],
        [ 1.3914e+03,  5.8967e+02,  4.9326e+02,  1.3072e+02, -1.1746e-01],
        [ 1.3912e+03,  5.8884e+02,  5.1746e+02,  1.3296e+02, -1.0046e-01]],
       device='cuda:0'), tensor([0.4558, 0.9813, 0.8270, 0.9195, 0.9672, 0.9919, 0.9903, 0.9746, 0.9700,
        0.0748, 0.5303, 0.9840, 0.4899, 0.0847, 0.1594, 0.0835, 0.1354, 0.0516,
        0.1017, 0.6517, 0.0837, 0.1967, 0.3002, 0.0909, 0.0872, 0.7068, 0.8541],
       device='cuda:0'), tensor([ 5,  6, 11,  1,  7,  8,  4,  3, 26,  2, 25, 19, 10, 12,  0, 22, 21, 14,
        16, 18, 23, 24, 13, 20, 15,  9, 17], device='cuda:0'), tensor([[ 1.3828e+03,  5.8986e+02,  4.9943e+02,  1.3395e+02, -5.6785e-02],
        [ 1.3844e+03,  5.8931e+02,  5.1308e+02,  1.3199e+02, -9.3728e-02],
        [ 1.3931e+03,  5.8952e+02,  5.0725e+02,  1.3281e+02, -1.4515e-01],
        [ 1.3778e+03,  5.9055e+02,  4.9475e+02,  1.3340e+02,  4.2467e-03],
        [ 1.3776e+03,  5.8879e+02,  5.1078e+02,  1.3006e+02, -1.0194e-01],
        [ 1.3818e+03,  5.9021e+02,  5.0582e+02,  1.3181e+02, -3.0323e-02],
        [ 1.3763e+03,  5.9193e+02,  4.8936e+02,  1.3456e+02,  6.2798e-02],
        [ 1.3828e+03,  5.9207e+02,  4.6885e+02,  1.3934e+02,  6.5348e-02],
        [ 1.3912e+03,  5.8884e+02,  5.1746e+02,  1.3296e+02, -1.0046e-01],
        [ 1.3880e+03,  5.8972e+02,  4.8726e+02,  1.3731e+02, -9.5322e-02],
        [ 1.3914e+03,  5.8967e+02,  4.9326e+02,  1.3072e+02, -1.1746e-01],
        [ 1.3769e+03,  5.9002e+02,  5.0865e+02,  1.3009e+02, -1.0078e-01],
        [ 1.3907e+03,  5.9019e+02,  4.8419e+02,  1.2927e+02, -4.7811e-02],
        [ 1.3969e+03,  5.8813e+02,  4.8288e+02,  1.3330e+02,  7.8660e-02],
        [ 1.3929e+03,  5.8900e+02,  4.4586e+02,  1.4374e+02, -5.9560e-02],
        [ 1.3822e+03,  5.8850e+02,  4.9508e+02,  1.3048e+02, -3.3254e-02],
        [ 1.3815e+03,  5.8972e+02,  5.1432e+02,  1.2936e+02, -5.7829e-02],
        [ 1.4419e+03,  5.8991e+02,  3.7714e+02,  1.3363e+02, -1.8755e-01],
        [ 1.3749e+03,  5.8739e+02,  4.5484e+02,  1.4451e+02, -5.0357e-02],
        [ 1.4243e+03,  5.8959e+02,  4.2513e+02,  1.3117e+02, -1.7885e-01],
        [ 1.3725e+03,  5.9071e+02,  4.4382e+02,  1.3830e+02, -1.3818e-01],
        [ 1.3911e+03,  5.8871e+02,  5.6314e+02,  1.3148e+02,  6.5076e-02],
        [ 1.3843e+03,  5.9290e+02,  4.3540e+02,  1.4532e+02, -5.1058e-02],
        [ 1.3739e+03,  5.9448e+02,  4.7890e+02,  1.4622e+02,  4.9585e-02],
        [ 1.4195e+03,  5.9087e+02,  4.5214e+02,  1.3127e+02, -9.4837e-02],
        [ 1.3793e+03,  5.9176e+02,  4.1503e+02,  1.4637e+02, -9.3242e-02],
        [ 1.4041e+03,  5.8720e+02,  4.8444e+02,  1.2699e+02, -1.0760e-01]],
       device='cuda:0'), 0.1, False
 

File "<stdin>", line 1, in <module> File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\mmdet\apis\inference.py", line 157, in inference_detector results = model(return_loss=False, rescale=True, **data) File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_func return old_func(*args, **kwargs) File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\mmdet\models\detectors\base.py", line 174, in forward return self.forward_test(img, img_metas, **kwargs) File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\mmdet\models\detectors\base.py", line 147, in forward_test return self.simple_test(imgs[0], img_metas[0], **kwargs) File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\detectors\two_stage.py", line 183, in simple_test return self.roi_head.simple_test( File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\roi_heads\rotate_standard_roi_head.py", line 252, in simple_test det_bboxes, det_labels = self.simple_test_bboxes( File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\roi_heads\oriented_standard_roi_head.py", line 178, in simple_test_bboxes File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\roi_heads\rotate_standar d_roi_head.py", line 252, in simple_test det_bboxes, det_labels = self.simple_test_bboxes( File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\roi_heads\oriented_stand ard_roi_head.py", line 178, in simple_test_bboxes det_bbox, det_label = self.bbox_head.get_bboxes( File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\mmcv\runner\fp16_utils.py", line 20 8, in new_func return old_func(*args, 
**kwargs) File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\roi_heads\bbox_heads\rot ated_bbox_head.py", line 418, in get_bboxes det_bboxes, det_labels = multiclass_nms_rotated( File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\core\post_processing\bbox_nms_r otated.py", line 58, in multiclass_nms_rotated bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
03-24
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值