前言:
在FasterRCNN或者MaskRCNN中,在通过RPN筛选出大小不同的一系列预测边框后,需要在Box_Head层对选出的预测边框进行进一步筛选,去除掉边框IoU介于两种阈值之间的预测边框,然后再使用ROI Pooling将处理后的预测边框池化为大小一致的边框。 然后对这些边框进行进一步的特征提取,之后再在提取后的特征上进行进一步的边框预测,得到每一个预测边框的类别得分以及它们的边框回归信息,并以此来得到最终的预测边框以及其类别信息。其代码详解如下:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn
from .roi_box_feature_extractors import make_roi_box_feature_extractor
from .roi_box_predictors import make_roi_box_predictor
from .inference import make_roi_box_post_processor
from .loss import make_roi_box_loss_evaluator
class ROIBoxHead(torch.nn.Module):
"""
Generic Box Head class.
"""
def __init__(self, cfg, in_channels):
super(ROIBoxHead, self).__init__()
# 指定ROI层中box_head模块的特征提取类
self.feature_extractor = make_roi_box_feature_extractor(cfg, in_cha