【NMS】nms_multiclass.m

本文探讨了在YOLO等网络中使用的多类别非极大抑制(NMS)算法,详细介绍了其工作原理和实现过程,特别是在大量候选框处理上的优化策略。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在Faster R-CNN中,常见的NMS位于RPN的最后一步,用于合并IoU较高的重复候选框。由于RPN的预测只区分前景和背景,目前我看到的代码中所用的NMS都不是多类别的NMS;而在YOLO这类网络中,则需要用到多类别的NMS。

function picks = nms_multiclass(boxes, overlap)
% picks = nms_multiclass(boxes, overlap)
% Per-class greedy non-maximum suppression (NMS).
%
% Inputs:
%   boxes   - m-by-n matrix, one row per bounding box. Columns 1:4 are the
%             box coordinates in (x1, y1, x2, y2) format; columns 5:n hold
%             one confidence score per class, i.e. n-4 scores per box.
%   overlap - IoU threshold. Within a class, a box whose IoU with an
%             already-selected box is greater than this value is skipped.
%
% Output:
%   picks   - (n-4)-by-1 cell array. picks{c} is a column vector of row
%             indices into boxes selected for class c (score column c+4).
%             The MATLAB implementation below returns them in descending
%             score order. Returns {} when boxes is empty.
%
% Notes:
%   - Coordinates are treated as inclusive, so width = x2 - x1 + 1.
%   - Every box takes part in every class; no score threshold is applied.
%   - Small inputs are delegated to nms_multiclass_mex when that MEX-file
%     is on the path; otherwise the pure-MATLAB code below is used.
%
% Non-maximum suppression. (FAST VERSION)
% Greedily select high-scoring detections and skip detections
% that are significantly covered by a previously selected
% detection.
%
% NOTE: This is adapted from Pedro Felzenszwalb's version (nms.m),
% but an inner loop has been eliminated to significantly speed it
% up in the case of a large number of boxes

% Copyright (C) 2011-12 by Tomasz Malisiewicz
% All rights reserved.
% 
% This file is part of the Exemplar-SVM library and is made
% available under the terms of the MIT license (see COPYING file).
% Project homepage: https://github.com/quantombone/exemplarsvm


if isempty(boxes)
  picks = {};
  return;
end

% Use the compiled implementation for small inputs when it is available
% (exist(..., 'file') == 3 means a MEX-file is on the path). Otherwise fall
% through to the pure-MATLAB implementation instead of erroring out.
if size(boxes, 1) < 10000 && exist('nms_multiclass_mex', 'file') == 3
    picks = nms_multiclass_mex(double(boxes), double(overlap));
    return;
end

% Work in double precision, as the MEX path does. With integer-typed input
% the IoU division below would otherwise round to 0 or 1.
boxes = double(boxes);
overlap = double(overlap);

x1 = boxes(:,1);
y1 = boxes(:,2);
x2 = boxes(:,3);
y2 = boxes(:,4);

area = (x2-x1+1) .* (y2-y1+1);

numClasses = size(boxes, 2) - 4;
picks = cell(numClasses, 1);
% Run NMS independently within each class (score column c+4).
for c = 1:numClasses
    s = boxes(:, c + 4);
    [~, I] = sort(s);           % ascending, so the best remaining box is I(end)

    pick = zeros(numel(s), 1);  % preallocated; trimmed to 'counter' below
    counter = 0;
    while ~isempty(I)
      last = numel(I);
      i = I(last);
      counter = counter + 1;
      pick(counter) = i;

      % All still-unsuppressed candidates other than the one just picked.
      rest = I(1:last-1);

      xx1 = max(x1(i), x1(rest));
      yy1 = max(y1(i), y1(rest));
      xx2 = min(x2(i), x2(rest));
      yy2 = min(y2(i), y2(rest));

      w = max(0.0, xx2-xx1+1);
      h = max(0.0, yy2-yy1+1);

      inter = w.*h;
      o = inter ./ (area(i) + area(rest) - inter);   % IoU with box i

      % Keep only candidates that do not overlap box i too much.
      I = rest(o <= overlap);
    end

    picks{c} = pick(1:counter);
end

 

torchpack dist-run -np 1 python tools/train.py configs/nuscenes/det/transfusion/secfpn/camera+lidar/swint_v0p075/convfuser.yaml --model.encoders.camera.backbone.init_cfg.checkpoint pretrained/swint-nuimages-pretrained.pth --load_from pretrained/lidar-only-det.pth Invalid MIT-MAGIC-COOKIE-1 keyTraceback (most recent call last): File "tools/train.py", line 15, in <module> from mmdet3d.datasets import build_dataset File "/media/wangbaihui/1ecf654b-afad-4dab-af7b-e34b00dda87a/bevfusion/mmdet3d/datasets/__init__.py", line 4, in <module> from .custom_3d import * File "/media/wangbaihui/1ecf654b-afad-4dab-af7b-e34b00dda87a/bevfusion/mmdet3d/datasets/custom_3d.py", line 10, in <module> from ..core.bbox import get_box_type File "/media/wangbaihui/1ecf654b-afad-4dab-af7b-e34b00dda87a/bevfusion/mmdet3d/core/__init__.py", line 4, in <module> from .post_processing import * # noqa: F401, F403 File "/media/wangbaihui/1ecf654b-afad-4dab-af7b-e34b00dda87a/bevfusion/mmdet3d/core/post_processing/__init__.py", line 5, in <module> from .box3d_nms import aligned_3d_nms, box3d_multiclass_nms, circle_nms File "/media/wangbaihui/1ecf654b-afad-4dab-af7b-e34b00dda87a/bevfusion/mmdet3d/core/post_processing/box3d_nms.py", line 1, in <module> import numba File "/home/wangbaihui/anaconda3/envs/vad/lib/python3.8/site-packages/numba/__init__.py", line 55, in <module> _ensure_critical_deps() File "/home/wangbaihui/anaconda3/envs/vad/lib/python3.8/site-packages/numba/__init__.py", line 40, in _ensure_critical_deps raise ImportError(msg) ImportError: Numba needs NumPy 1.22 or greater. Got NumPy 1.19. -------------------------------------------------------------------------- Primary job terminated normally, but 1 process returned a non-zero exit code. Per user-direction, the job has been aborted. ----------------------------------------------------------------------
最新发布
07-22
D:\pythonProject1\PaddleDetection-release-2.8.1>python tools/export_model.py -c configs/ppyolo/ppyolo_r18vd_coco.yml --output_dir ./inference_model -o weights=tools/output/249.pdparams 信息: 用提供的模式无法找到文件。 Warning: Unable to use JDE/FairMOT/ByteTrack, please install lap, for example: `pip install lap`, see https://github.com/gatagat/lap Warning: Unable to use numba in PP-Tracking, please install numba, for example(python3.7): `pip install numba==0.56.4` Warning: Unable to use numba in PP-Tracking, please install numba, for example(python3.7): `pip install numba==0.56.4` [06/05 14:50:46] ppdet.utils.checkpoint INFO: Skipping import of the encryption module. Warning: Unable to use MOT metric, please install motmetrics, for example: `pip install motmetrics`, see https://github.com/longcw/py-motmetrics Warning: Unable to use MCMOT metric, please install motmetrics, for example: `pip install motmetrics`, see https://github.com/longcw/py-motmetrics [06/05 14:50:47] ppdet.utils.checkpoint INFO: Finish loading model weights: tools/output/249.pdparams Traceback (most recent call last): File "D:\pythonProject1\PaddleDetection-release-2.8.1\tools\export_model.py", line 148, in <module> main() File "D:\pythonProject1\PaddleDetection-release-2.8.1\tools\export_model.py", line 144, in main run(FLAGS, cfg) File "D:\pythonProject1\PaddleDetection-release-2.8.1\tools\export_model.py", line 105, in run trainer.export(FLAGS.output_dir, for_fd=FLAGS.for_fd) File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\engine\trainer.py", line 1294, in export static_model, pruned_input_spec, input_spec = self._get_infer_cfg_and_input_spec( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\engine\trainer.py", line 1240, in _get_infer_cfg_and_input_spec static_model, pruned_input_spec = self._model_to_static(model, input_spec, prune_input) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File 
"D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\engine\trainer.py", line 1155, in _model_to_static input_spec, static_model.forward.main_program, ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\jit\dy2static\program_translator.py", line 1118, in main_program concrete_program = self.concrete_program ^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\jit\dy2static\program_translator.py", line 1002, in concrete_program return self.concrete_program_specify_input_spec(input_spec=None) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\jit\dy2static\program_translator.py", line 1046, in concrete_program_specify_input_spec concrete_program, _ = self.get_concrete_program( ^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\jit\dy2static\program_translator.py", line 935, in get_concrete_program concrete_program, partial_program_layer = self._program_cache[ ^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\jit\dy2static\program_translator.py", line 1694, in __getitem__ self._caches[item_id] = self._build_once(item) ^^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\jit\dy2static\program_translator.py", line 1631, in _build_once concrete_program = ConcreteProgram.pir_from_func_spec( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\decorator.py", line 235, in fun return caller(func, *(extras + args), **kw) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\base\wrapped_decorator.py", line 40, in __impl__ return wrapped_func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File 
"C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\base\dygraph\base.py", line 101, in __impl__ return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\jit\dy2static\program_translator.py", line 1302, in pir_from_func_spec error_data.raise_new_exception() File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\jit\dy2static\error.py", line 454, in raise_new_exception raise new_exception from None TypeError: In transformed code: File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\architectures\meta_arch.py", line 59, in forward if self.training: File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\architectures\meta_arch.py", line 69, in forward for inp in inputs_list: File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\architectures\meta_arch.py", line 76, in forward outs.append(self.get_pred()) File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\architectures\yolo.py", line 150, in get_pred return self._forward() File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\architectures\yolo.py", line 92, in _forward if self.training: File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\architectures\yolo.py", line 103, in _forward if self.for_mot: File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\architectures\yolo.py", line 115, in _forward if self.return_idx: File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\architectures\yolo.py", line 119, in _forward elif self.post_process is not None: File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\architectures\yolo.py", line 121, in _forward bbox, bbox_num, nms_keep_idx = self.post_process( File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\post_process.py", line 69, in __call__ if self.nms is not None: File 
"D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\post_process.py", line 71, in __call__ bbox_pred, bbox_num, before_nms_indexes = self.nms(bboxes, score, File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\layers.py", line 605, in __call__ def __call__(self, bbox, score, *args): return ops.matrix_nms( ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE bboxes=bbox, scores=score, File "D:\pythonProject1\PaddleDetection-release-2.8.1\ppdet\modeling\ops.py", line 714, in matrix_nms helper.append_op( File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\base\layer_helper.py", line 57, in append_op return self.main_program.current_block().append_op(*args, **kwargs) File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\base\framework.py", line 4701, in append_op op = Operator( File "C:\Users\Admin\AppData\Roaming\Python\Python311\site-packages\paddle\base\framework.py", line 3329, in __init__ raise TypeError( TypeError: The type of '%BBoxes' in operator matrix_nms should be one of [str, bytes, Variable]. but received : Value(define_op_name=pd_op.concat, index=0, dtype=tensor<-1x3840x4xf32>, stop_gradient=False) 中文回答
06-06
File "<stdin>", line 1, in <module> File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\mmdet\apis\inference.py", line 157, in inference_detector results = model(return_loss=False, rescale=True, **data) File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl return forward_call(*args, **kwargs) File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\mmcv\runner\fp16_utils.py", line 119, in new_func return old_func(*args, **kwargs) File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\mmdet\models\detectors\base.py", line 174, in forward return self.forward_test(img, img_metas, **kwargs) File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\mmdet\models\detectors\base.py", line 147, in forward_test return self.simple_test(imgs[0], img_metas[0], **kwargs) File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\detectors\two_stage.py", line 183, in simple_test return self.roi_head.simple_test( File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\roi_heads\rotate_standard_roi_head.py", line 252, in simple_test det_bboxes, det_labels = self.simple_test_bboxes( File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\roi_heads\oriented_standard_roi_head.py", line 178, in simple_test_bboxes File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\roi_heads\rotate_standar d_roi_head.py", line 252, in simple_test det_bboxes, det_labels = self.simple_test_bboxes( File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\roi_heads\oriented_stand ard_roi_head.py", line 178, in simple_test_bboxes det_bbox, det_label = self.bbox_head.get_bboxes( File "C:\Users\86018\anaconda3\envs\cuihui\lib\site-packages\mmcv\runner\fp16_utils.py", line 20 8, in new_func return old_func(*args, 
**kwargs) File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\models\roi_heads\bbox_heads\rot ated_bbox_head.py", line 418, in get_bboxes det_bboxes, det_labels = multiclass_nms_rotated( File "D:\CH_files\硕士\科研\代码\LWGANet-main\detection\mmrotate\core\post_processing\bbox_nms_r otated.py", line 58, in multiclass_nms_rotated bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds] RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
03-24
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值