测试 demo 时出现了以下错误（节选，完整报错信息见文末）：
keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
TypeError: nms_rotated(): incompatible function arguments. The following argument types are supported:
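解决方法：找到当前环境中已安装的 mmcv 包里 “mmcv\ops\nms.py” 文件（例如本例报错中的 D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\ops\nms.py）中的 “nms_rotated” 函数，
将其替换为以下代码即可（代码取自 mmcv 1.7.2 文档中 mmcv.ops.nms 的源码；粘贴时请保留 Python 的缩进）：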
def nms_rotated(dets: Tensor,
                scores: Tensor,
                iou_threshold: float,
                labels: Optional[Tensor] = None,
                clockwise: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
    """Performs non-maximum suppression (NMS) on the rotated boxes according to
    their intersection-over-union (IoU).

    Rotated NMS iteratively removes lower scoring rotated boxes which have an
    IoU greater than iou_threshold with another (higher scoring) rotated box.

    Args:
        dets (torch.Tensor): Rotated boxes in shape (N, 5).
            They are expected to be in
            (x_ctr, y_ctr, width, height, angle_radian) format.
        scores (torch.Tensor): scores in shape (N, ).
        iou_threshold (float): IoU thresh for NMS.
        labels (torch.Tensor, optional): boxes' label in shape (N,).
            When given, ``multi_label`` mode is enabled and the labels are
            appended to the boxes as an extra column before calling the
            compiled op.
        clockwise (bool): flag indicating whether the positive angular
            orientation is clockwise. default True.
            `New in version 1.4.3.`

    Returns:
        tuple: ``(kept_dets, keep_inds)``. ``kept_dets`` holds the kept boxes
        (taken from the original ``dets``, so angles keep the caller's
        convention) with their score appended as the last column, and
        ``keep_inds`` holds the indices of the kept boxes. Both have the same
        data type as the input. If ``dets`` is empty, ``dets`` is returned
        as given (no score column appended) together with ``None`` instead
        of an index tensor.
    """
    if dets.shape[0] == 0:
        # Nothing to suppress. Callers must be prepared for ``None`` indices.
        return dets, None

    if not clockwise:
        # Negate the last (angle) column so the boxes follow the clockwise
        # convention used for ``dets_cw``.
        flip_mat = dets.new_ones(dets.shape[-1])
        flip_mat[-1] = -1
        dets_cw = dets * flip_mat
    else:
        dets_cw = dets

    multi_label = labels is not None
    # The compiled op declares ``labels`` as a required tensor argument, so an
    # empty int tensor is passed when no labels are given. Omitting this
    # argument makes the call fail with
    # "nms_rotated(): incompatible function arguments".
    if labels is None:
        input_labels = scores.new_empty(0, dtype=torch.int)
    else:
        input_labels = labels

    if dets.device.type in ('npu', 'mlu'):
        # NPU/MLU path: the unsorted boxes fill both the ``dets`` and
        # ``dets_sorted`` slots and ``order`` is an empty tensor.
        # NOTE(review): presumably the device kernel does its own sorting.
        order = scores.new_empty(0, dtype=torch.long)
        keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,
                                           input_labels, iou_threshold,
                                           multi_label)
        dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
                         dim=1)
        return dets, keep_inds

    # In multi-label mode the labels travel with the boxes as an extra column.
    if multi_label:
        dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1)  # type: ignore
    else:
        dets_wl = dets_cw
    # Sort by descending score; the op receives both the sort order and the
    # boxes already rearranged in that order.
    _, order = scores.sort(0, descending=True)
    dets_sorted = dets_wl.index_select(0, order)

    if torch.__version__ == 'parrots':
        keep_inds = ext_module.nms_rotated(
            dets_wl,
            scores,
            order,
            dets_sorted,
            input_labels,
            iou_threshold=iou_threshold,
            multi_label=multi_label)
    else:
        keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
                                           input_labels, iou_threshold,
                                           multi_label)
    # Kept boxes come from the original ``dets`` (not ``dets_cw``); append
    # each kept box's score as an extra column.
    dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
                     dim=1)
    return dets, keep_inds
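主要原因：新版本 nms.py 中的 nms_rotated 在调用已编译的扩展算子 ext_module.nms_rotated 时少传了一个 “input_labels” 参数。从文末的报错信息可以看到，已编译的算子要求 7 个参数（dets, scores, order, dets_sorted, labels, iou_threshold, multi_label），而实际只传入了 6 个（缺少 labels），因此抛出 TypeError。两个版本的调用对比如下：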
# Newer version of the code: ext_module.nms_rotated is called WITHOUT the
# labels argument (6 arguments), which the compiled op rejects with
# "nms_rotated(): incompatible function arguments".
keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
                                   iou_threshold, multi_label)
# Older version of the code (as in mmcv 1.7.2): passes input_labels, matching
# the 7-argument signature (dets, scores, order, dets_sorted, labels,
# iou_threshold, multi_label) reported in the TypeError below.
keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
                                   input_labels, iou_threshold, multi_label)
完整的错误信息如下:
ETA:D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\__init__.py:20: UserWarning: On January 1, 2023, MMCV will release v2.0.0, in which it will remove components related to the training process and add a data transformation module. In addition, it will rename the package names mmcv to mmcv-lite and mmcv-full to mmcv. See https://github.com/open-mmlab/mmcv/blob/master/docs/en/compatibility.md for more details.
warnings.warn(
D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\__init__.py:20: UserWarning: On January 1, 2023, MMCV will release v2.0.0, in which it will remove components related to the training process and add a data transformation module. In addition, it will rename the package names mmcv to mmcv-lite and mmcv-full to mmcv. See https://github.com/open-mmlab/mmcv/blob/master/docs/en/compatibility.md for more details.
warnings.warn(
Traceback (most recent call last):
File "D:/PythonProject/mmrotate-main/tools/train.py", line 194, in <module>
main()
File "D:/PythonProject/mmrotate-main/tools/train.py", line 183, in main
train_detector(
File "D:\PythonProject\mmrotate-main\mmrotate\apis\train.py", line 144, in train_detector
runner.run(data_loaders, cfg.workflow)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 136, in run
epoch_runner(data_loaders[i], **kwargs)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 58, in train
self.call_hook('after_train_epoch')
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\base_runner.py", line 317, in call_hook
getattr(hook, fn_name)(self)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\hooks\evaluation.py", line 271, in after_train_epoch
self._do_evaluate(runner)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmdet\core\evaluation\eval_hooks.py", line 60, in _do_evaluate
results = single_gpu_test(runner.model, self.dataloader, show=False)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmdet\apis\test.py", line 29, in single_gpu_test
result = model(return_loss=False, rescale=True, **data)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\parallel\data_parallel.py", line 51, in forward
return super().forward(*inputs, **kwargs)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\torch\nn\parallel\data_parallel.py", line 183, in forward
return self.module(*inputs[0], **module_kwargs[0])
File "D:\anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_func
return old_func(*args, **kwargs)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 174, in forward
return self.forward_test(img, img_metas, **kwargs)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmdet\models\detectors\base.py", line 147, in forward_test
return self.simple_test(imgs[0], img_metas[0], **kwargs)
File "D:\PythonProject\mmrotate-main\mmrotate\models\detectors\two_stage.py", line 183, in simple_test
return self.roi_head.simple_test(
File "D:\PythonProject\mmrotate-main\mmrotate\models\roi_heads\roi_trans_roi_head.py", line 333, in simple_test
det_bbox, det_label = self.bbox_head[-1].get_bboxes(
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\runner\fp16_utils.py", line 208, in new_func
return old_func(*args, **kwargs)
File "D:\PythonProject\mmrotate-main\mmrotate\models\roi_heads\bbox_heads\rotated_bbox_head.py", line 418, in get_bboxes
det_bboxes, det_labels = multiclass_nms_rotated(
File "D:\PythonProject\mmrotate-main\mmrotate\core\post_processing\bbox_nms_rotated.py", line 80, in multiclass_nms_rotated
_, keep = nms_rotated(bboxes_for_nms, scores, nms.iou_thr)
File "D:\anaconda3\envs\openmmlab\lib\site-packages\mmcv\ops\nms.py", line 473, in nms_rotated
keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
TypeError: nms_rotated(): incompatible function arguments. The following argument types are supported:
1. (dets: torch.Tensor, scores: torch.Tensor, order: torch.Tensor, dets_sorted: torch.Tensor, labels: torch.Tensor, iou_threshold: float, multi_label: int) -> torch.Tensor
Invoked with: tensor([[ 1.3929e+03, 5.8900e+02, 4.4586e+02, 1.4374e+02, -5.9560e-02],
[ 1.3778e+03, 5.9055e+02, 4.9475e+02, 1.3340e+02, 4.2467e-03],
[ 1.3880e+03, 5.8972e+02, 4.8726e+02, 1.3731e+02, -9.5322e-02],
[ 1.3828e+03, 5.9207e+02, 4.6885e+02, 1.3934e+02, 6.5348e-02],
[ 1.3763e+03, 5.9193e+02, 4.8936e+02, 1.3456e+02, 6.2798e-02],
[ 1.3828e+03, 5.8986e+02, 4.9943e+02, 1.3395e+02, -5.6785e-02],
[ 1.3844e+03, 5.8931e+02, 5.1308e+02, 1.3199e+02, -9.3728e-02],
[ 1.3776e+03, 5.8879e+02, 5.1078e+02, 1.3006e+02, -1.0194e-01],
[ 1.3818e+03, 5.9021e+02, 5.0582e+02, 1.3181e+02, -3.0323e-02],
[ 1.3793e+03, 5.9176e+02, 4.1503e+02, 1.4637e+02, -9.3242e-02],
[ 1.3907e+03, 5.9019e+02, 4.8419e+02, 1.2927e+02, -4.7811e-02],
[ 1.3931e+03, 5.8952e+02, 5.0725e+02, 1.3281e+02, -1.4515e-01],
[ 1.3969e+03, 5.8813e+02, 4.8288e+02, 1.3330e+02, 7.8660e-02],
[ 1.3843e+03, 5.9290e+02, 4.3540e+02, 1.4532e+02, -5.1058e-02],
[ 1.4419e+03, 5.8991e+02, 3.7714e+02, 1.3363e+02, -1.8755e-01],
[ 1.4195e+03, 5.9087e+02, 4.5214e+02, 1.3127e+02, -9.4837e-02],
[ 1.3749e+03, 5.8739e+02, 4.5484e+02, 1.4451e+02, -5.0357e-02],
[ 1.4041e+03, 5.8720e+02, 4.8444e+02, 1.2699e+02, -1.0760e-01],
[ 1.4243e+03, 5.8959e+02, 4.2513e+02, 1.3117e+02, -1.7885e-01],
[ 1.3769e+03, 5.9002e+02, 5.0865e+02, 1.3009e+02, -1.0078e-01],
[ 1.3739e+03, 5.9448e+02, 4.7890e+02, 1.4622e+02, 4.9585e-02],
[ 1.3815e+03, 5.8972e+02, 5.1432e+02, 1.2936e+02, -5.7829e-02],
[ 1.3822e+03, 5.8850e+02, 4.9508e+02, 1.3048e+02, -3.3254e-02],
[ 1.3725e+03, 5.9071e+02, 4.4382e+02, 1.3830e+02, -1.3818e-01],
[ 1.3911e+03, 5.8871e+02, 5.6314e+02, 1.3148e+02, 6.5076e-02],
[ 1.3914e+03, 5.8967e+02, 4.9326e+02, 1.3072e+02, -1.1746e-01],
[ 1.3912e+03, 5.8884e+02, 5.1746e+02, 1.3296e+02, -1.0046e-01]],
device='cuda:0'), tensor([0.4558, 0.9813, 0.8270, 0.9195, 0.9672, 0.9919, 0.9903, 0.9746, 0.9700,
0.0748, 0.5303, 0.9840, 0.4899, 0.0847, 0.1594, 0.0835, 0.1354, 0.0516,
0.1017, 0.6517, 0.0837, 0.1967, 0.3002, 0.0909, 0.0872, 0.7068, 0.8541],
device='cuda:0'), tensor([ 5, 6, 11, 1, 7, 8, 4, 3, 26, 2, 25, 19, 10, 12, 0, 22, 21, 14,
16, 18, 23, 24, 13, 20, 15, 9, 17], device='cuda:0'), tensor([[ 1.3828e+03, 5.8986e+02, 4.9943e+02, 1.3395e+02, -5.6785e-02],
[ 1.3844e+03, 5.8931e+02, 5.1308e+02, 1.3199e+02, -9.3728e-02],
[ 1.3931e+03, 5.8952e+02, 5.0725e+02, 1.3281e+02, -1.4515e-01],
[ 1.3778e+03, 5.9055e+02, 4.9475e+02, 1.3340e+02, 4.2467e-03],
[ 1.3776e+03, 5.8879e+02, 5.1078e+02, 1.3006e+02, -1.0194e-01],
[ 1.3818e+03, 5.9021e+02, 5.0582e+02, 1.3181e+02, -3.0323e-02],
[ 1.3763e+03, 5.9193e+02, 4.8936e+02, 1.3456e+02, 6.2798e-02],
[ 1.3828e+03, 5.9207e+02, 4.6885e+02, 1.3934e+02, 6.5348e-02],
[ 1.3912e+03, 5.8884e+02, 5.1746e+02, 1.3296e+02, -1.0046e-01],
[ 1.3880e+03, 5.8972e+02, 4.8726e+02, 1.3731e+02, -9.5322e-02],
[ 1.3914e+03, 5.8967e+02, 4.9326e+02, 1.3072e+02, -1.1746e-01],
[ 1.3769e+03, 5.9002e+02, 5.0865e+02, 1.3009e+02, -1.0078e-01],
[ 1.3907e+03, 5.9019e+02, 4.8419e+02, 1.2927e+02, -4.7811e-02],
[ 1.3969e+03, 5.8813e+02, 4.8288e+02, 1.3330e+02, 7.8660e-02],
[ 1.3929e+03, 5.8900e+02, 4.4586e+02, 1.4374e+02, -5.9560e-02],
[ 1.3822e+03, 5.8850e+02, 4.9508e+02, 1.3048e+02, -3.3254e-02],
[ 1.3815e+03, 5.8972e+02, 5.1432e+02, 1.2936e+02, -5.7829e-02],
[ 1.4419e+03, 5.8991e+02, 3.7714e+02, 1.3363e+02, -1.8755e-01],
[ 1.3749e+03, 5.8739e+02, 4.5484e+02, 1.4451e+02, -5.0357e-02],
[ 1.4243e+03, 5.8959e+02, 4.2513e+02, 1.3117e+02, -1.7885e-01],
[ 1.3725e+03, 5.9071e+02, 4.4382e+02, 1.3830e+02, -1.3818e-01],
[ 1.3911e+03, 5.8871e+02, 5.6314e+02, 1.3148e+02, 6.5076e-02],
[ 1.3843e+03, 5.9290e+02, 4.3540e+02, 1.4532e+02, -5.1058e-02],
[ 1.3739e+03, 5.9448e+02, 4.7890e+02, 1.4622e+02, 4.9585e-02],
[ 1.4195e+03, 5.9087e+02, 4.5214e+02, 1.3127e+02, -9.4837e-02],
[ 1.3793e+03, 5.9176e+02, 4.1503e+02, 1.4637e+02, -9.3242e-02],
[ 1.4041e+03, 5.8720e+02, 4.8444e+02, 1.2699e+02, -1.0760e-01]],
device='cuda:0'), 0.1, False