前言
本篇将介绍build_roi_box_head()函数,这个函数是在your_project/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py文件中,
def build_roi_box_head(cfg, in_channels):
"""
Constructs a new box head.
By default, uses ROIBoxHead, but if it turns out not to be enough, just register a new class
and make it a parameter in the config
"""
# 主要返回一个ROIBoxHead类对象
return ROIBoxHead(cfg, in_channels)
一、ROIBoxHead类
从代码可知,build_roi_box_head()主要是返回一个ROIBoxHead类对象,看来我们主要需要了解的目标就是这个ROIBoxHead类了,这个类也是在your_project/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py文件中,我们首先看一下__init__()函数:
class ROIBoxHead(torch.nn.Module):
"""
Generic Box Head class.
"""
def __init__(self, cfg, in_channels):
super(ROIBoxHead, self).__init__()
# ROI层中的特征提取器(先进行ROI Align,后续有没有特征提取操作看具体head的方法)
# 因为RPN提取的Proposals大小都不太一样,为了使得这些Proposals的图片特征大小一样,
# 需要进行ROI Align操作得到大小一样的特征。
self.feature_extractor = make_roi_box_feature_extractor(cfg, in_channels)
# ROI层中的边框预测类(用于类别的分类和box的回归~)
self.predictor = make_roi_box_predictor(
cfg, self.feature_extractor.out_channels)
# 下面这两个和RPN中的很相像
# ROI层中的后处理类(inference过程 进行NMS操作和box解码等操作)
self.post_processor = make_roi_box_post_processor(cfg)
# 训练过程计算loss
self.loss_evaluator = make_roi_box_loss_evaluator(cfg)
接着我们看一下该类的forward()函数,了解该类的一个处理流程:
def forward(self, features, proposals, targets=None):
"""
Arguments:
features (list[Tensor]): feature-maps from possibly several levels
proposa