《解析 MMDetection 中的 BaseRoIHead》
在目标检测任务中,Region of Interest(RoI)Heads 起着至关重要的作用,负责对感兴趣区域进行进一步的处理和预测。mmdet/models/roi_heads/base_roi_head.py
文件定义了一个基础的 RoI Head 类,为各种具体的 RoI Head 实现提供了抽象的框架。
一、类的继承与元类
BaseRoIHead
类继承自BaseModule
和使用ABCMeta
作为元类。这使得该类既可以作为一个普通的 PyTorch 模块进行初始化和使用,又具有抽象基类的特性,可以通过抽象方法强制子类实现特定的功能。
二、初始化方法
def __init__(self,
bbox_roi_extractor: OptMultiConfig = None,
bbox_head: OptMultiConfig = None,
mask_roi_extractor: OptMultiConfig = None,
mask_head: OptMultiConfig = None,
shared_head: OptConfigType = None,
train_cfg: OptConfigType = None,
test_cfg: OptConfigType = None,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg=init_cfg)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
if shared_head is not None:
self.shared_head = MODELS.build(shared_head)
if bbox_head is not None:
self.init_bbox_head(bbox_roi_extractor, bbox_head)
if mask_head is not None:
self.init_mask_head(mask_roi_extractor, mask_head)
self.init_assigner_sampler()
这个初始化方法接受多个参数,包括用于边界框提取、边界框头部、掩码提取、掩码头部、共享头部以及训练和测试配置等。如果shared_head
不为None
,则使用MODELS.build
函数构建共享头部模块。如果bbox_head
和mask_head
不为None
,分别调用init_bbox_head
和init_mask_head
方法初始化边界框头部和掩码头部。最后,调用init_assigner_sampler
方法初始化分配器和采样器。
三、属性方法
with_bbox
属性:
@property
def with_bbox(self) -> bool:
"""bool: whether the RoI head contains a `bbox_head`"""
return hasattr(self, 'bbox_head') and self.bbox_head is not None
这个属性用于判断当前的 RoI Head 是否包含边界框头部。如果对象具有bbox_head
属性且不为None
,则返回True
,否则返回False
。
2. with_mask
属性:
@property
def with_mask(self) -> bool:
"""bool: whether the RoI head contains a `mask_head`"""
return hasattr(self, 'mask_head') and self.mask_head is not None
类似地,这个属性用于判断 RoI Head 是否包含掩码头部。
3. with_shared_head
属性:
@property
def with_shared_head(self) -> bool:
"""bool: whether the RoI head contains a `shared_head`"""
return hasattr(self, 'shared_head') and self.shared_head is not None
判断 RoI Head 是否包含共享头部。
四、抽象方法
init_bbox_head
方法:
@abstractmethod
def init_bbox_head(self, *args, **kwargs):
"""Initialize ``bbox_head``"""
pass
这是一个抽象方法,要求子类实现用于初始化边界框头部的功能。
2. init_mask_head
方法:
@abstractmethod
def init_mask_head(self, *args, **kwargs):
"""Initialize ``mask_head``"""
pass
抽象方法,用于初始化掩码头部。
3. init_assigner_sampler
方法:
@abstractmethod
def init_assigner_sampler(self, *args, **kwargs):
"""Initialize assigner and sampler."""
pass
要求子类实现初始化分配器和采样器的功能。
五、loss
方法
@abstractmethod
def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList,
batch_data_samples: SampleList):
"""Perform forward propagation and loss calculation of the roi head on
the features of the upstream network."""
这是一个抽象方法,用于在上游网络的特征上执行 RoI Head 的前向传播和损失计算。
六、predict
方法
def predict(self,
x: Tuple[Tensor],
rpn_results_list: InstanceList,
batch_data_samples: SampleList,
rescale: bool = False) -> InstanceList:
assert self.with_bbox, 'Bbox head must be implemented.'
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]
# TODO: nms_op in mmcv need be enhanced, the bbox result may get
# difference when not rescale in bbox_head
# If it has the mask branch, the bbox branch does not need
# to be scaled to the original image scale, because the mask
# branch will scale both bbox and mask at the same time.
bbox_rescale = rescale if not self.with_mask else False
results_list = self.predict_bbox(
x,
batch_img_metas,
rpn_results_list,
rcnn_test_cfg=self.test_cfg,
rescale=bbox_rescale)
if self.with_mask:
results_list = self.predict_mask(
x, batch_img_metas, results_list, rescale=rescale)
return results_list
这个方法用于在给定上游网络的特征、区域提议列表和数据样本列表的情况下进行预测。首先,它检查是否存在边界框头部,如果没有则抛出异常。然后,获取图像的元信息。根据是否有掩码分支来确定边界框是否需要缩放到原始图像尺度。接着,调用predict_bbox
方法进行边界框的预测。如果存在掩码分支,则调用predict_mask
方法进行掩码的预测,并将结果返回。