Center Point源码详细解读3

根据前面的博客我们应该已经知道了center point的模型基本结构与运行流程,前面讲到第一阶段的检测结构,这篇博客将介绍第二阶段的检测。

1、特征提取bird_eye_view

第一阶段经过了 SpMiddleResNetFHD、RPN(Region Proposal Network)、 CenterHead之后,将成的特征作为第二阶段的输入,使用 BEVFeatureExtractor 模块提取 BEV 特征。

位置:CenterPoint-master\det3d\models\second_stage\bird_eye_view.py

代码:

class BEVFeatureExtractor(nn.Module): 
    def __init__(self, pc_start, 
            voxel_size, out_stride):
        super().__init__(
        self.pc_start = pc_start 
        self.voxel_size = voxel_size
        self.out_stride = out_stride

    def absl_to_relative(self, absolute):
        a1 = (absolute[..., 0] - self.pc_start[0]) / self.voxel_size[0] / self.out_stride 
        a2 = (absolute[..., 1] - self.pc_start[1]) / self.voxel_size[1] / self.out_stride 

        return a1, a2

    def forward(self, example, batch_centers, num_point):
        batch_size = len(example['bev_feature'])
        ret_maps = [] 

        for batch_idx in range(batch_size):
            xs, ys = self.absl_to_relative(batch_centers[batch_idx])
            
            # N x C 
            feature_map = bilinear_interpolate_torch(example['bev_feature'][batch_idx],
             xs, ys)

            if num_point > 1:
                section_size = len(feature_map) // num_point
                feature_map = torch.cat([feature_map[i*section_size: (i+1)*section_size] for i in range(num_point)], dim=1)

            ret_maps.append(feature_map)

        return ret_maps 

这一部分的代码比较简单,很容易就能看明白。建议大家还是从前向传播函数来看。

作用:用于提取基于 BEV(Bird's Eye View,鸟瞰图)的特征,返回 ret_maps,即每个批次中每个中心坐标附近的特征图列表。BEVFeatureExtractor 接受第一阶段检测器的输出作为输入,同时接受中心坐标和采样点数作为参数,生成中心坐标附近的特征图。

2、第二阶段检测器RoIHead

Center point使用 RoIHead 作为第二阶段的检测器,输入特征包括 BEV 特征和其他信息,同时也使用 RoIHead 对提取的特征图进行分类和回归,生成最终的检测结果。

位置:CenterPoint-master\det3d\models\roi_heads\roi_head.py

代码:

class RoIHead(RoIHeadTemplate):
    def __init__(self, input_channels, model_cfg, num_class=1, code_size=7, add_box_param=False, test_cfg=None):
        super().__init__(num_class=num_class, model_cfg=model_cfg)
        self.model_cfg = model_cfg
        self.test_cfg = test_cfg 
        self.code_size = code_size
        self.add_box_param = add_box_param

        pre_channel = input_channels

        shared_fc_list = []
        for k in range(0, self.model_cfg.SHARED_FC.__len__()):
            shared_fc_list.extend([
                nn.Conv1d(pre_channel, self.model_cfg.SHARED_FC[k], kernel_size=1, bias=False),
                nn.BatchNorm1d(self.model_cfg.SHARED_FC[k]),
                nn.ReLU()
            ])
            pre_channel = self.model_cfg.SHARED_FC[k]

            if k != self.model_cfg.SHARED_FC.__len__() - 1 and self.model_cfg.DP_RATIO > 0:
                shared_fc_list.append(nn.Dropout(self.model_cfg.DP_RATIO))

        self.shared_fc_layer = nn.Sequential(*shared_fc_list)

        self.cls_layers = self.make_fc_layers(
            input_channels=pre_channel, output_channels=self.num_class, fc_list=self.model_cfg.CLS_FC
        )
        self.reg_layers = self.make_fc_layers(
            input_channels=pre_channel,
            output_channels=code_size,
            fc_list=self.model_cfg.REG_FC
        )
        self.init_weights(weight_init='xavier')

    def init_weights(self, weight_init='xavier'):
        if weight_init == 'kaiming':
            init_func = nn.init.kaiming_normal_
        elif weight_init == 'xavier':
            init_func = nn.init.xavier_normal_
        elif weight_init == 'normal':
            init_func = nn.init.normal_
        else:
            raise NotImplementedError

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
                if weight_init == 'normal':
                    init_func(m.weight, mean=0, std=0.001)
                else:
                    init_func(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
        nn.init.normal_(self.reg_layers[-1].weight, mean=0, std=0.001)

    def forward(self, batch_dict, training=True):
        """
        :param input_data: input dict
        :return:
        """
        batch_dict['batch_size'] = len(batch_dict['rois'])
        if training:
            targets_dict = self.assign_targets(batch_dict)
            batch_dict['rois'] = targets_dict['rois']
            batch_dict['roi_labels'] = targets_dict['roi_labels']
            batch_dict['roi_features'] = targets_dict['roi_features']
            batch_dict['roi_scores'] = targets_dict['roi_scores']

        # RoI aware pooling
        if self.add_box_param:
            batch_dict['roi_features'] = torch.cat([batch_dict['roi_features'], batch_dict['rois'], batch_dict['roi_scores'].unsqueeze(-1)], dim=-1)

        pooled_features = batch_dict['roi_features'].reshape(-1, 1,
            batch_dict['roi_features'].shape[-1]).contiguous()  # (BxN, 1, C)

        batch_size_rcnn = pooled_features.shape[0]
        pooled_features = pooled_features.permute(0, 2, 1).contiguous() # (BxN, C, 1)

        shared_features = self.shared_fc_layer(pooled_features.view(batch_size_rcnn, -1, 1))
        rcnn_cls = self.cls_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, 1 or 2)
        rcnn_reg = self.reg_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1)  # (B, C)

        if not training:
            batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
                batch_size=batch_dict['batch_size'], rois=batch_dict['rois'], cls_preds=rcnn_cls, box_preds=rcnn_reg
            )
            batch_dict['batch_cls_preds'] = batch_cls_preds
            batch_dict['batch_box_preds'] = batch_box_preds
            batch_dict['cls_preds_normalized'] = False
        else:
            targets_dict['rcnn_cls'] = rcnn_cls
            targets_dict['rcnn_reg'] = rcnn_reg

            self.forward_ret_dict = targets_dict
        
        return batch_dict

注:center point的运行并不是只涉及到我所理出来的代码哦,还有很多其他的代码。比如这个代码,第一行可以知道class RoIHead(RoIHeadTemplate),它是继承自RoIHeadTemplate这个类的,而这个类的代码位置在:CenterPoint-master\det3d\models\roi_heads\roi_head_template.py所以大家如果想要更加深入的理解代码还是要仔细得看代码,我这里只是给大家提供一个运行的思路方便大家去有目标顺序的看代码。

作用:定义了一个 RoI 池化头部(RoIHead)的类,用于在目标检测任务中对区域提议(RoIs)进行分类和回归,以生成最终的检测结果。这里还涉及到损失值的计算与定义,不过多赘述。

以上,就是centerpoint源码的以nusc_two_stage_base_with_virtual.py作为配置文件的详细运行流程,具体代码内容还需要大家仔细的去理解。

下一篇博客我将会去学习center point的损失函数。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值