While testing the demo, I ran into this error:
keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
TypeError: nms_rotated(): incompatible function arguments. The following argument types are supported:
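Solution: open "mmcv\ops\nms.py" in your mmcv installation, find the "nms_rotated" function,
and replace it with the following code (source: the mmcv.ops.nms page of the mmcv 1.7.2 documentation):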
# In mmcv/ops/nms.py: torch, Tensor, Optional, Tuple and ext_module are
# already imported at the top of that file.
def nms_rotated(dets: Tensor,
                scores: Tensor,
                iou_threshold: float,
                labels: Optional[Tensor] = None,
                clockwise: bool = True) -> Tuple[Tensor, Tensor]:
    """Performs non-maximum suppression (NMS) on the rotated boxes according to
    their intersection-over-union (IoU).

    Rotated NMS iteratively removes lower scoring rotated boxes which have an
    IoU greater than iou_threshold with another (higher scoring) rotated box.

    Args:
        dets (torch.Tensor): Rotated boxes in shape (N, 5).
            They are expected to be in
            (x_ctr, y_ctr, width, height, angle_radian) format.
        scores (torch.Tensor): scores in shape (N, ).
        iou_threshold (float): IoU thresh for NMS.
        labels (torch.Tensor, optional): boxes' label in shape (N,).
        clockwise (bool): flag indicating whether the positive angular
            orientation is clockwise. default True.
            `New in version 1.4.3.`

    Returns:
        tuple: kept dets(boxes and scores) and indice, which is always the
        same data type as the input.
    """
    if dets.shape[0] == 0:
        return dets, None
    # Flip the sign of the angle column so the kernel always sees clockwise angles.
    if not clockwise:
        flip_mat = dets.new_ones(dets.shape[-1])
        flip_mat[-1] = -1
        dets_cw = dets * flip_mat
    else:
        dets_cw = dets
    multi_label = labels is not None
    # Always pass a tensor to the extension: an empty int tensor when no labels.
    if labels is None:
        input_labels = scores.new_empty(0, dtype=torch.int)
    else:
        input_labels = labels
    if dets.device.type in ('npu', 'mlu'):
        order = scores.new_empty(0, dtype=torch.long)
        keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw,
                                           input_labels, iou_threshold,
                                           multi_label)
        dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
                         dim=1)
        return dets, keep_inds

    if multi_label:
        dets_wl = torch.cat((dets_cw, labels.unsqueeze(1)), 1)  # type: ignore
    else:
        dets_wl = dets_cw
    _, order = scores.sort(0, descending=True)
    dets_sorted = dets_wl.index_select(0, order)

    if torch.__version__ == 'parrots':
        keep_inds = ext_module.nms_rotated(
            dets_wl,
            scores,
            order,
            dets_sorted,
            input_labels,
            iou_threshold=iou_threshold,
            multi_label=multi_label)
    else:
        keep_inds = ext_module.nms_rotated(dets_wl, scores, order, dets_sorted,
                                           input_labels, iou_threshold,
                                           multi_label)
    # Append each kept box's score as the last column and return with indices.
    dets = torch.cat((dets[keep_inds], scores[keep_inds].reshape(-1, 1)),
                     dim=1)
    return dets, keep_inds
