maskrcnn_benchmark Code Explained: inference.py

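This article walks through the RPNPostProcessor class in the inference.py file of the maskrcnn_benchmark library. The class extracts object proposal boxes from the different FPN levels, filtering them and applying NMS according to parameters such as pre_nms_top_n, post_nms_top_n, nms_thresh and min_size. It also introduces the make_rpn_postprocessor function, which instantiates RPNPostProcessor from user-supplied parameters.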


Preface:

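In maskrcnn_benchmark, selecting good proposal boxes from the FPN levels is handled by inference.py. That file defines the RPNPostProcessor class, which uses a few predefined parameters to select boxes either across multiple FPN levels (P2-P5) or from a single designated output feature map (for example C5).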
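As a rough illustration of the multi-level case, the minimal pure-Python sketch below covers only the final merging step: the proposals of one image that survived filtering on each level are pooled, and only the highest-scoring ones are kept. The function name merge_levels_top_k and the (box, score) tuple layout are hypothetical; the library itself works on BoxList objects and torch tensors. With a single feature map such as C5 there is nothing to merge.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)
Scored = Tuple[Box, float]               # (box, objectness score)


def merge_levels_top_k(per_level: List[List[Scored]], fpn_post_nms_top_n: int) -> List[Scored]:
    """Pool one image's proposals from every level and keep the top-scoring ones.

    Hypothetical helper for illustration, not the library's API.
    """
    # Flatten the per-level lists into a single candidate list.
    merged = [item for level in per_level for item in level]
    # Rank all candidates by objectness, regardless of the level they came from.
    merged.sort(key=lambda item: item[1], reverse=True)
    return merged[:fpn_post_nms_top_n]


# Hypothetical proposals that survived per-level filtering on P2 and P3.
p2 = [((0, 0, 10, 10), 0.9), ((5, 5, 20, 20), 0.3)]
p3 = [((0, 0, 40, 40), 0.7), ((10, 10, 50, 50), 0.1)]
kept = merge_levels_top_k([p2, p3], fpn_post_nms_top_n=2)
print([score for _, score in kept])  # [0.9, 0.7]
```

Because the levels compete only by score, a level can contribute anywhere from zero to all of the final boxes.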

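The commonly used parameters are: pre_nms_top_n, the number of highest-scoring boxes kept on each feature level before NMS; post_nms_top_n, the number of boxes kept on each feature level after NMS (the final selection among the boxes gathered from all levels is controlled by fpn_post_nms_top_n, which falls back to post_nms_top_n when it is None);

nms_thresh, the IoU threshold used by non-maximum suppression (NMS); and min_size, the smallest acceptable box size (boxes whose width or height is below it are discarded).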
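How these four parameters interact on a single feature level can be sketched in plain Python. This is a simplified, hypothetical version (filter_one_level and iou are our own names, not library functions): it assumes the boxes are already decoded and clipped to the image, uses continuous (x2 - x1) widths rather than any pixel-offset convention the library may follow, and implements a textbook greedy NMS.

```python
from typing import List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two boxes, using continuous coordinates."""
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def filter_one_level(
    boxes: List[Box],
    scores: List[float],
    pre_nms_top_n: int,
    post_nms_top_n: int,
    nms_thresh: float,
    min_size: float,
) -> List[int]:
    """Return the indices of the boxes kept on one feature level (hypothetical sketch)."""
    # 1. Keep the pre_nms_top_n highest-scoring boxes.
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    order = order[:pre_nms_top_n]
    # 2. Drop boxes whose width or height is below min_size.
    order = [
        i for i in order
        if boxes[i][2] - boxes[i][0] >= min_size and boxes[i][3] - boxes[i][1] >= min_size
    ]
    # 3. Greedy NMS: visit boxes by descending score and discard any box that
    #    overlaps an already kept box by more than nms_thresh.
    kept: List[int] = []
    for i in order:
        if all(iou(boxes[i], boxes[k]) <= nms_thresh for k in kept):
            kept.append(i)
        # 4. Stop once post_nms_top_n boxes have survived.
        if len(kept) >= post_nms_top_n:
            break
    return kept


boxes = [(0, 0, 10, 10), (1, 0, 11, 10), (20, 20, 30, 30), (0, 0, 2, 2)]
scores = [0.9, 0.8, 0.7, 0.95]
print(filter_one_level(boxes, scores, pre_nms_top_n=4, post_nms_top_n=2,
                       nms_thresh=0.7, min_size=4))  # [0, 2]
```

In the example, box 3 has the highest score but is only 2 pixels wide, so min_size removes it; box 1 overlaps box 0 with an IoU of about 0.82, above nms_thresh, so NMS suppresses it.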

The code of the RPNPostProcessor class, with explanations, is as follows:

import torch

from maskrcnn_benchmark.modeling.box_coder import BoxCoder


class RPNPostProcessor(torch.nn.Module):
    """
    Performs post-processing on the outputs of the RPN boxes, before feeding the
    proposals to the heads
    """

    def __init__(
        self,
        pre_nms_top_n,
        post_nms_top_n,
        nms_thresh,
        min_size,
        box_coder=None,
        fpn_post_nms_top_n=None,
        fpn_post_nms_per_batch=True,
    ):
        """
        Arguments:
            pre_nms_top_n (int)
            post_nms_top_n (int)
            nms_thresh (float)
            min_size (int)
            box_coder (BoxCoder)
            fpn_post_nms_top_n (int)
            fpn_post_nms_per_batch (bool)
        """
        super(RPNPostProcessor, self).__init__()
        # Store the configuration parameters
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.min_size = min_size

        # Default to a BoxCoder with unit weights, i.e. the regression deltas are not rescaled
        if box_coder is None:
            box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
        self.box_coder = box_coder

        # If no separate cross-level limit is given, reuse post_nms_top_n
        if fpn_post_nms_top_n is None:
            fpn_post_nms_top_n = post_nms_top_n
        self.fpn_post_nms_top_n = fpn_post_nms_top_n
        self.fpn_post_nms_per_batch = fpn_post_nms_per_batch

    def add_gt_proposals(self, proposals, targets):
        """
        Arguments:
            proposals: list[BoxList]: predicted proposal boxes, one BoxList per image
            targets: list[BoxList]: ground-truth boxes, one BoxList per image
        """
        # Get the device we're operating on
        device = proposals[0].bbox.device
        # Copy the ground-truth boxes with an empty field list, keeping only their coordinates
        gt_boxes = [target.copy_with_fields([]) for target in targets]

        # later cat of bbox requires all fields to be present for all bbox
        # 
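Based on the comment above, the ground-truth boxes must carry the same fields as the proposals before each image's two box sets can be concatenated. The sketch below illustrates that idea in plain Python; SimpleBoxList and add_gt_proposals_sketch are hypothetical stand-ins for the library's BoxList and add_gt_proposals, the proposals are assumed to carry only an objectness field, and the dummy score of 1.0 given to the ground-truth boxes is an assumption. In Faster R-CNN style training, appending ground truth to the proposals is typically done so that the second stage always sees some well-matched positive samples.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


@dataclass
class SimpleBoxList:
    """Hypothetical stand-in for BoxList: boxes plus per-box fields."""
    boxes: List[Box]
    fields: Dict[str, List[float]] = field(default_factory=dict)


def add_gt_proposals_sketch(
    proposals: List[SimpleBoxList], targets: List[SimpleBoxList]
) -> List[SimpleBoxList]:
    """Append each image's ground-truth boxes to that image's proposals."""
    merged = []
    for prop, target in zip(proposals, targets):
        # Keep only the ground-truth coordinates, dropping fields such as labels.
        gt_boxes = list(target.boxes)
        # Concatenation needs the proposals' field on both sides, so give the
        # ground-truth boxes a dummy objectness score (1.0 is an assumption).
        gt_objectness = [1.0] * len(gt_boxes)
        merged.append(
            SimpleBoxList(
                boxes=prop.boxes + gt_boxes,
                fields={"objectness": prop.fields["objectness"] + gt_objectness},
            )
        )
    return merged


props = [SimpleBoxList([(0, 0, 10, 10)], {"objectness": [0.8]})]
gts = [SimpleBoxList([(2, 2, 8, 8)], {"labels": [3]})]
out = add_gt_proposals_sketch(props, gts)
print(out[0].boxes)   # [(0, 0, 10, 10), (2, 2, 8, 8)]
print(out[0].fields)  # {'objectness': [0.8, 1.0]}
```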