Sparse4Dv3 Code Study (II): Single-Frame Inference

Single-GPU (non-distributed) inference enters here:

    elif not distributed:   # taken
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)

Feature Extraction and Preparation

Extract the features:

    def simple_test(self, img, **data):
        feature_maps = self.extract_feat(img)   # img: torch.Size([1, 6, 3, 256, 704])

Inside extract_feat, the backbone branch is taken:

        else:   # taken
            feature_maps = self.img_backbone(img)

Four feature-map scales (from torch.Size([6, 256, 64, 176]) down to torch.Size([6, 2048, 8, 22]))

Through the FPN:

        if self.img_neck is not None:
            feature_maps = list(self.img_neck(feature_maps))

Four feature-map scales again, now all with 256 channels (from torch.Size([6, 256, 64, 176]) down to torch.Size([6, 256, 8, 22]))

Preparation for the deformable attention later on:

        if self.use_deformable_func:    # taken
            feature_maps = feature_maps_format(feature_maps)

Inside feature_maps_format, each feature map's height and width are flattened into one dimension:

    bs, num_cams = feature_maps[0].shape[:2]
    spatial_shape = []

    col_feats = []
    for i, feat in enumerate(feature_maps):
        spatial_shape.append(feat.shape[-2:])
        col_feats.append(
            torch.reshape(feat, (bs, num_cams, feat.shape[2], -1))
        )

Concatenate the four feature maps and permute the axes (89760 = (64*176 + 32*88 + 16*44 + 8*22) * 6 views):

    col_feats = torch.cat(col_feats, dim=-1).permute(0, 1, 3, 2).flatten(1, 2)  # torch.Size([1, 89760, 256])
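
As a sanity check, this standalone sketch rebuilds the flattening with dummy tensors at the four scales above and confirms the 89760 figure:

    import torch

    # dummy FPN outputs: [bs, num_cams, C, H, W] at the four scales
    shapes = [(64, 176), (32, 88), (16, 44), (8, 22)]
    feature_maps = [torch.zeros(1, 6, 256, h, w) for h, w in shapes]

    bs, num_cams = feature_maps[0].shape[:2]
    col_feats = [f.reshape(bs, num_cams, 256, -1) for f in feature_maps]
    col_feats = torch.cat(col_feats, dim=-1).permute(0, 1, 3, 2).flatten(1, 2)

    print(col_feats.shape)  # torch.Size([1, 89760, 256])
    assert col_feats.shape[1] == sum(h * w for h, w in shapes) * num_cams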

Replicate the spatial shapes 6 times, once per camera:

    spatial_shape = [spatial_shape] * num_cams

The starting offset of each level's features (spatial_shape is first converted to a tensor so it can be indexed this way):

    scale_start_index = spatial_shape[..., 0] * spatial_shape[..., 1]
    scale_start_index = scale_start_index.flatten().cumsum(dim=0)
    scale_start_index = torch.cat(
        [torch.tensor([0]).to(scale_start_index), scale_start_index[:-1]]
    )

Reshape into the 6 views:

    scale_start_index = scale_start_index.reshape(num_cams, -1)
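
A quick sketch of this bookkeeping with the four scales above; each row of the result holds one camera's four level offsets into col_feats:

    import torch

    shapes = torch.tensor([(64, 176), (32, 88), (16, 44), (8, 22)])
    num_cams = 6

    spatial_shape = shapes[None].repeat(num_cams, 1, 1)    # [6, 4, 2]
    sizes = spatial_shape[..., 0] * spatial_shape[..., 1]  # H*W per (cam, level)
    start = sizes.flatten().cumsum(dim=0)
    start = torch.cat([torch.tensor([0]), start[:-1]]).reshape(num_cams, -1)

    print(start[0])  # tensor([    0, 11264, 14080, 14784])  offsets within camera 0
    print(start[1])  # tensor([14960, 26224, 29040, 29744])  camera 1 starts at 14960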

GPU memory usage at this point: (screenshot omitted)

Entering Sparse4DHead:

    model_outs = self.head(feature_maps, data)

Entering the InstanceBank:

        (
            instance_feature,
            anchor,
            temp_instance_feature,
            temp_anchor,
            time_interval,
        ) = self.instance_bank.get(
            batch_size, metas, dn_metas=self.sampler.dn_metas
        )

metas holds basic metadata about the sample (timestamps, projection matrices, etc.)

self.instance_feature is an nn.Parameter() initialized to zeros

    def get(self, batch_size, metas=None, dn_metas=None):
        instance_feature = torch.tile(
            self.instance_feature[None], (batch_size, 1, 1) # self.instance_feature: torch.Size([900, 256])
        )

anchor contains the 900 initialized anchor boxes:

    anchor = torch.tile(self.anchor[None], (batch_size, 1, 1))  # torch.Size([1, 900, 11])

For the first frame there are no cached instances, so the else branch is taken:

        else:   # taken (first frame, nothing cached)
            self.reset()
            time_interval = instance_feature.new_tensor(
                [self.default_time_interval] * batch_size
            )

Anchor embedding encoding:

    anchor_embed = self.anchor_encoder(anchor)

Here box_3d is just the anchor from above: torch.Size([1, 900, 11])


    def forward(self, box_3d: torch.Tensor):
        pos_feat = self.pos_fc(box_3d[..., [X, Y, Z]])  # torch.Size([1, 900, 128])
        size_feat = self.size_fc(box_3d[..., [W, L, H]])    # torch.Size([1, 900, 32])
        yaw_feat = self.yaw_fc(box_3d[..., [SIN_YAW, COS_YAW]]) # torch.Size([1, 900, 32])

Four LayerNorms (one per branch):

concat: the final output feature vector is torch.Size([1, 900, 256])

        elif self.mode == "cat":    # taken
            output = torch.cat([pos_feat, size_feat, yaw_feat], dim=-1) # torch.Size([1, 900, 192])

        if self.vel_dims > 0:   # velocity present, taken
            vel_feat = self.vel_fc(box_3d[..., VX : VX + self.vel_dims])    # torch.Size([1, 900, 64])
            if self.mode == "add":
                output = output + vel_feat
            elif self.mode == "cat":
                output = torch.cat([output, vel_feat], dim=-1)  # torch.Size([1, 900, 256])

The concrete steps (39 sub-steps in total):

They fall into groups: groups one through five are each (00: 'deformable'; 01: 'ffn'; 02: 'norm'; 03: 'refine'; 04: 'temp_gnn'; 05: 'gnn'; 06: 'norm'), and the sixth group (35: 'deformable'; 36: 'ffn'; 37: 'norm'; 38: 'refine') drops the last three steps. A sketch of the list follows.
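
The 39-step order can be reconstructed as below (a sketch consistent with the indices above; the repo builds this list in the config):

    single_frame = ["deformable", "ffn", "norm", "refine"]
    temporal = ["temp_gnn", "gnn", "norm"]
    operation_order = single_frame + (temporal + single_frame) * 5

    assert len(operation_order) == 39
    assert operation_order[4] == "temp_gnn" and operation_order[38] == "refine"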

①'deformable' (DeformableFeatureAggregation)
            elif op == "deformable":
                instance_feature = self.layers[i](
                    instance_feature,
                    anchor,
                    anchor_embed,
                    feature_maps,
                    metas,
                )
Generate the key points (SparseBox3DKeyPointsGenerator):

    key_points = self.kps_generator(anchor, instance_feature)

7 fixed points:

        size = anchor[..., None, [W, L, H]].exp()   # torch.Size([1, 900, 1, 3])
        key_points = self.fix_scale * size  # torch.Size([1, 900, 7, 3]): fix_scale holds 7 scale offsets

6 learned points:

        if self.num_learnable_pts > 0 and instance_feature is not None: # taken
            learnable_scale = ( # torch.Size([1, 900, 6, 3]): the 6 learned points
                self.learnable_fc(instance_feature)
                .reshape(bs, num_anchor, self.num_learnable_pts, 3)
                .sigmoid()
                - 0.5
            )
            key_points = torch.cat(
                [key_points, learnable_scale * size], dim=-2    # concatenated: torch.Size([1, 900, 13, 3])
            )

Apply the yaw rotation (13 points in total):

        rotation_mat = anchor.new_zeros([bs, num_anchor, 3, 3]) # torch.Size([1, 900, 3, 3])

        rotation_mat[:, :, 0, 0] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 0, 1] = -anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 0] = anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 1] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 2, 2] = 1

        key_points = torch.matmul(
            rotation_mat[:, :, None], key_points[..., None]
        ).squeeze(-1)   # torch.Size([1, 900, 13, 3])  rotated
        key_points = key_points + anchor[..., None, [X, Y, Z]]  # torch.Size([1, 900, 13, 3])  translated to the box center
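
Putting the three fragments together, a condensed self-contained sketch; the 7-entry fix_scale below is an assumption mirroring the repo's default (box center plus the 6 face centers):

    import torch

    X, Y, Z, W, L, H, SIN_YAW, COS_YAW = range(8)

    fix_scale = torch.tensor([  # assumed default: center + 6 face centers
        [0, 0, 0],
        [0.45, 0, 0], [-0.45, 0, 0],
        [0, 0.45, 0], [0, -0.45, 0],
        [0, 0, 0.45], [0, 0, -0.45],
    ])

    def keypoints(anchor, learnable_scale):
        bs, num_anchor = anchor.shape[:2]
        size = anchor[..., None, [W, L, H]].exp()             # [bs, 900, 1, 3]
        pts = torch.cat(
            [fix_scale * size, learnable_scale * size], dim=-2
        )                                                     # [bs, 900, 13, 3]
        rot = anchor.new_zeros(bs, num_anchor, 3, 3)
        rot[..., 0, 0] = anchor[..., COS_YAW]
        rot[..., 0, 1] = -anchor[..., SIN_YAW]
        rot[..., 1, 0] = anchor[..., SIN_YAW]
        rot[..., 1, 1] = anchor[..., COS_YAW]
        rot[..., 2, 2] = 1
        pts = (rot[:, :, None] @ pts[..., None]).squeeze(-1)  # rotate by yaw
        return pts + anchor[..., None, [X, Y, Z]]             # translate to box center

    anchor = torch.zeros(1, 900, 11)
    anchor[..., COS_YAW] = 1  # identity yaw
    print(keypoints(anchor, torch.rand(1, 900, 6, 3) - 0.5).shape)  # [1, 900, 13, 3]
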
Get the weight of each key point:

    weights = self._get_weights(instance_feature, anchor_embed, metas)

Camera projection-matrix encoding:

    def _get_weights(self, instance_feature, anchor_embed, metas=None):
        bs, num_anchor = instance_feature.shape[:2]
        feature = instance_feature + anchor_embed
        if self.camera_encoder is not None: # taken
            camera_embed = self.camera_encoder(
                metas["projection_mat"][:, :, :3].reshape(
                    bs, self.num_cams, -1
                )
            )
            feature = feature[:, :, None] + camera_embed[:, None]

Camera encoding: 6 cameras, each encoded from the first 3 rows of its 4×4 projection matrix (3×4 = 12 parameters):

The instance features plus the camera embeddings form the feature from which the weights are computed: 256 dims map to 416 weights (4 feature-map levels × 13 points × 8 attention groups):

        weights = (
            self.weights_fc(feature)
            .reshape(bs, num_anchor, -1, self.num_groups)
            .softmax(dim=-2)
            .reshape(
                bs,
                num_anchor,
                self.num_cams,
                self.num_levels,
                self.num_pts,
                self.num_groups,
            )
        )        # torch.Size([1, 900, 6, 4, 13, 8])
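
Note the softmax normalizes over the flattened 6×4×13 = 312 sampling slots separately for each of the 8 groups; a quick shape-and-normalization check:

    import torch

    bs, num_anchor, num_cams, num_levels, num_pts, num_groups = 1, 900, 6, 4, 13, 8
    raw = torch.randn(bs, num_anchor, num_cams * num_levels * num_pts * num_groups)

    weights = (
        raw.reshape(bs, num_anchor, -1, num_groups)
        .softmax(dim=-2)  # normalize over the 6*4*13 = 312 sampling slots per group
        .reshape(bs, num_anchor, num_cams, num_levels, num_pts, num_groups)
    )
    # each group's weights over all cameras/levels/points sum to 1
    assert torch.allclose(weights.sum(dim=(2, 3, 4)), torch.ones(bs, num_anchor, num_groups))
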
Compute the pixel coordinates of the key points projected onto the images:

        if self.use_deformable_func:
            points_2d = (
                self.project_points(
                    key_points,
                    metas["projection_mat"],
                    metas.get("image_wh"),
                )
                .permute(0, 2, 3, 1, 4)
                .reshape(bs, num_anchor, self.num_pts, self.num_cams, 2)
            )

The details: the core step is the coordinate projection, which maps the 13 key points of each of the 900 anchors onto the 6 views to get the projected pixel coordinates:

    @staticmethod
    def project_points(key_points, projection_mat, image_wh=None):
        bs, num_anchor, num_pts = key_points.shape[:3]  # torch.Size([1, 900, 13, 3])

        pts_extend = torch.cat(
            [key_points, torch.ones_like(key_points[..., :1])], dim=-1
        )   # torch.Size([1, 900, 13, 4])  append a column of ones on the right (homogeneous coordinates)
        points_2d = torch.matmul(   # torch.Size([1, 6, 900, 13, 4])  the 13 points projected into each of the 6 views
            projection_mat[:, :, None, None], pts_extend[:, None, ..., None]
        ).squeeze(-1)   # projection_mat: torch.Size([1, 6, 4, 4])
        points_2d = points_2d[..., :2] / torch.clamp(
            points_2d[..., 2:3], min=1e-5
        )   # perspective division by depth
        if image_wh is not None:
            points_2d = points_2d / image_wh[:, :, None, None]  # normalize to [0, 1]  torch.Size([1, 6, 900, 13, 2])
        return points_2d
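
A toy check of project_points with a single made-up pinhole camera (focal length 500, principal point at the center of a 704×256 image):

    import torch

    def project_points(key_points, projection_mat, image_wh=None):
        pts = torch.cat([key_points, torch.ones_like(key_points[..., :1])], dim=-1)
        p2d = torch.matmul(
            projection_mat[:, :, None, None], pts[:, None, ..., None]
        ).squeeze(-1)
        p2d = p2d[..., :2] / torch.clamp(p2d[..., 2:3], min=1e-5)
        if image_wh is not None:
            p2d = p2d / image_wh[:, :, None, None]
        return p2d

    # one camera looking down +z: u = 500*x/z + 352, v = 500*y/z + 128
    P = torch.tensor([[[[500., 0., 352., 0.],
                        [0., 500., 128., 0.],
                        [0., 0., 1., 0.],
                        [0., 0., 0., 1.]]]])    # [bs=1, cams=1, 4, 4]
    pt = torch.tensor([[[[1., 0.5, 10.]]]])     # [bs=1, anchors=1, pts=1, 3]
    wh = torch.tensor([[[704., 256.]]])         # [bs=1, cams=1, 2]
    print(project_points(pt, P, wh))            # ≈ 0.571, 0.598 (normalized u, v)
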
The actual deformable aggregation:

            features = DAF(*feature_maps, points_2d, weights).reshape(
                bs, num_anchor, self.embed_dims
            )

Inside is just the custom CUDA operator; it outputs the 256-dim feature for each of the 900 anchors:

from torch.autograd import Function
# deformable_aggregation_ext is the compiled CUDA extension shipped with the repo

class DeformableAggregationFunction(Function):
    @staticmethod
    def forward(
        ctx,
        mc_ms_feat,
        spatial_shape,
        scale_start_index,
        sampling_location,
        weights,
    ):
        # output: [bs, num_pts, num_embeds]
        mc_ms_feat = mc_ms_feat.contiguous().float()
        spatial_shape = spatial_shape.contiguous().int()
        scale_start_index = scale_start_index.contiguous().int()
        sampling_location = sampling_location.contiguous().float()
        weights = weights.contiguous().float()
        output = deformable_aggregation_ext.deformable_aggregation_forward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        )   # torch.Size([1, 900, 256])
        ctx.save_for_backward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        )
        return output
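
For intuition, here is a rough pure-PyTorch equivalent of what the kernel computes, written against the unflattened list-form feature maps (before feature_maps_format); the channel-per-group layout and the grid_sample alignment convention are assumptions, so treat it as a sketch rather than a drop-in replacement:

    import torch
    import torch.nn.functional as F

    def deformable_aggregation_ref(feats, points_2d, weights):
        # feats:     num_levels tensors of shape [bs, num_cams, C, H, W]
        # points_2d: [bs, num_anchor, num_pts, num_cams, 2], normalized to [0, 1]
        # weights:   [bs, num_anchor, num_cams, num_levels, num_pts, num_groups]
        bs, num_anchor, num_pts, num_cams, _ = points_2d.shape
        C = feats[0].shape[2]
        num_groups = weights.shape[-1]
        out = feats[0].new_zeros(bs, num_anchor, C)
        grid = points_2d * 2 - 1  # grid_sample expects coordinates in [-1, 1]
        for lvl, feat in enumerate(feats):
            for cam in range(num_cams):
                # bilinear sample: [bs, C, num_anchor, num_pts]
                sampled = F.grid_sample(
                    feat[:, cam], grid[:, :, :, cam],
                    mode="bilinear", padding_mode="zeros", align_corners=False,
                )
                w = weights[:, :, cam, lvl]  # [bs, num_anchor, num_pts, num_groups]
                # assumed layout: each group weights a contiguous block of C/num_groups channels
                w = w.repeat_interleave(C // num_groups, dim=-1)
                out = out + (sampled.permute(0, 2, 3, 1) * w).sum(dim=2)
        return out  # [bs, num_anchor, C]

Summing over cameras, levels, and points, whose weights were softmax-normalized per group above, yields one 256-dim feature per anchor.
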
Back in DeformableFeatureAggregation, the sampled features go through the output projection:

    output = self.proj_drop(self.output_proj(features)) # torch.Size([1, 900, 256])

It is then concatenated with the original instance_feature, giving a 256 + 256 = 512-dim feature vector:

        elif self.residual_mode == "cat":
            output = torch.cat([output, instance_feature], dim=-1)  # torch.Size([1, 900, 512])

②'ffn' (AsymmetricFFN)

The process (a residual connection around the FFN):

    def forward(self, x, identity=None):
        if self.pre_norm is not None:
            x = self.pre_norm(x)    # torch.Size([1, 900, 512])
        out = self.layers(x)    # torch.Size([1, 900, 256])  back down to 256 dims
        if not self.add_identity:   # not taken
            return self.dropout_layer(out)
        if identity is None:    # taken
            identity = x
        identity = self.identity_fc(identity)   # torch.Size([1, 900, 256])  512 -> 256 so the shapes match
        return identity + self.dropout_layer(out)   # residual connection

③'norm'

④'refine' (SparseBox3DRefinementModule)

The feature produces a correction for the anchor (the first 8 dims):

        feature = instance_feature + anchor_embed   # torch.Size([1, 900, 256])  the instance feature now carries the image features sampled by the deformable aggregation
        output = self.layers(feature)   # torch.Size([1, 900, 11])  256 dims down to 11
        output[..., self.refine_state] = (
            output[..., self.refine_state] + anchor[..., self.refine_state]
        )   # self.refine_state covers indices 0..7 (8 entries); the output equals the correction plus the anchor's own value

Dividing by the inter-frame time interval gives the correction for the velocity in the three directions (the last 3 dims):

        if self.output_dim > 8: # taken
            if not isinstance(time_interval, torch.Tensor): # not taken
                time_interval = instance_feature.new_tensor(time_interval)
            translation = torch.transpose(output[..., VX:], 0, -1)
            velocity = torch.transpose(translation / time_interval, 0, -1)  # torch.Size([1, 900, 3])  velocities along the three axes; the transposes broadcast the per-batch time_interval
            output[..., VX:] = velocity + anchor[..., VX:]  # add the correction

Prediction over the 10 object classes:

        if return_cls:  # taken
            assert self.with_cls_branch, "Without classification layers !!!"
            cls = self.cls_layers(instance_feature) # torch.Size([1, 900, 10])

Quality estimation:

        if return_cls and self.with_quality_estimation: # taken
            quality = self.quality_layers(feature)  # torch.Size([1, 900, 2])

This corresponds to the quality estimation (centerness and yawness) defined in the paper.

Save the predicted boxes, classes, and quality:

                prediction.append(anchor)
                classification.append(cls)
                quality.append(qt)
Update the instance bank (InstanceBank):

                if len(prediction) == self.num_single_frame_decoder:    # num_single_frame_decoder == 1
                    instance_feature, anchor = self.instance_bank.update(
                        instance_feature, anchor, cls
                    )   # update the instance bank

On the first frame, self.cached_feature is None, so update() returns the inputs unchanged:

    def update(self, instance_feature, anchor, confidence):
        if self.cached_feature is None: # None on the first frame
            return instance_feature, anchor
If this is not the last operation, the anchor must be re-encoded:

                if i != len(self.operation_order) - 1:
                    anchor_embed = self.anchor_encoder(anchor)  # re-encode the refined anchor

⑤'temp_gnn'

            elif op == "temp_gnn":
                instance_feature = self.graph_model(
                    i,
                    instance_feature,
                    temp_instance_feature,
                    temp_instance_feature,
                    query_pos=anchor_embed,
                    key_pos=temp_anchor_embed,
                    attn_mask=attn_mask
                    if temp_instance_feature is None
                    else None,
                )

From the paper:

As mentioned in the introduction, we make simple improvements to the anchor encoder, self-attention, and temporal cross-attention of Sparse4Dv2. The architecture is shown in Figure 5. The design principle is to combine features from different modalities by concatenation rather than by addition. There are some differences compared with Conditional DETR [33]. First, our improvement is to the attention between queries rather than the cross-attention between queries and image features; the cross-attention still uses the deformable aggregation from Sparse4D. In addition, instead of concatenating the position embedding and the query feature inside the multi-head attention, we make the modification outside the multi-head attention, which gives the neural network greater flexibility.

Figure 5: Architecture of the anchor encoder and attention. We independently encode multiple components of the anchor into high-dimensional features and then concatenate them. Compared with the original Sparse4D, this approach has lower computational and parameter overhead. E and F denote the anchor embedding and the instance feature, respectively.

⑥'gnn'

The gnn has the same structure as the temp_gnn above, but attends within the current instances.

Finally, after the 6 groups of steps above, the prediction results are collected:

        output.update(
            {
                "classification": classification,
                "prediction": prediction,
                "quality": quality,
            }
        )

Only two entries of classification and quality are non-None (the output of the single-frame decoder layer and the final layer):

prediction has an output from all 6 layers:

Caching the current instance features

        # cache current instances for temporal modeling
        self.instance_bank.cache(
            instance_feature, anchor, cls, metas, feature_maps
        )

Gradients are detached when caching:

        instance_feature = instance_feature.detach()
        anchor = anchor.detach()
        confidence = confidence.detach()

Compute the confidences:

        confidence = confidence.max(dim=-1).values.sigmoid()
        if self.confidence is not None:
            confidence[:, : self.num_temp_instances] = torch.maximum(
                self.confidence * self.confidence_decay,
                confidence[:, : self.num_temp_instances],
            )
        self.temp_confidence = confidence

Select the 600 instance features and anchors with the highest confidence:

        (
            self.confidence,
            (self.cached_feature, self.cached_anchor),
        ) = topk(confidence, self.num_temp_instances, instance_feature, anchor)

The details:

def topk(confidence, k, *inputs):
    bs, N = confidence.shape[:2]
    confidence, indices = torch.topk(confidence, k, dim=1)  # keep the 600 highest confidences
    indices = (
        indices + torch.arange(bs, device=indices.device)[:, None] * N
    ).reshape(-1)
    outputs = []
    for input in inputs:
        outputs.append(input.flatten(end_dim=1)[indices].reshape(bs, k, -1))    # gather the instance features and anchors by index
    return confidence, outputs
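
A toy run (made-up numbers) showing the batch-offset trick that turns per-batch top-k indices into flat indices:

    import torch

    confidence = torch.tensor([[0.1, 0.9, 0.4, 0.8]])   # bs=1, N=4
    features = torch.arange(4 * 2, dtype=torch.float).reshape(1, 4, 2)

    conf, idx = torch.topk(confidence, k=2, dim=1)      # values [0.9, 0.8], indices [1, 3]
    flat_idx = (idx + torch.arange(1)[:, None] * 4).reshape(-1)
    picked = features.flatten(end_dim=1)[flat_idx].reshape(1, 2, -1)
    print(picked)  # rows 1 and 3 of features: [[[2., 3.], [6., 7.]]]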

At the end of inference, instance IDs are assigned:

        if not self.training:
            instance_id = self.instance_bank.get_instance_id(
                cls, anchor, self.decoder.score_threshold
            )
            output["instance_id"] = instance_id

Initialize the instance IDs:

    def get_instance_id(self, confidence, anchor=None, threshold=None):
        confidence = confidence.max(dim=-1).values.sigmoid()
        instance_id = confidence.new_full(confidence.shape, -1).long()  # initialize all instance IDs to -1

Filter out the new instances and assign them fresh IDs:

        mask = instance_id < 0
        if threshold is not None:
            mask = mask & (confidence >= threshold)
        num_new_instance = mask.sum()
        new_ids = torch.arange(num_new_instance).to(instance_id) + self.prev_id
        instance_id[torch.where(mask)] = new_ids
        self.prev_id += num_new_instance    # advance the ID counter for the next frame

Prepare the instances to pass to the next frame:

        if self.num_temp_instances > 0:
            self.update_instance_id(instance_id, confidence)

The details:

    def update_instance_id(self, instance_id=None, confidence=None):
        if self.temp_confidence is None:
            if confidence.dim() == 3:  # bs, num_anchor, num_cls
                temp_conf = confidence.max(dim=-1).values
            else:  # bs, num_anchor
                temp_conf = confidence
        else:
            temp_conf = self.temp_confidence
        instance_id = topk(temp_conf, self.num_temp_instances, instance_id)[1][
            0
        ]   # torch.Size([1, 600, 1])  keep the IDs of the 600 highest-confidence instances out of the 900
        instance_id = instance_id.squeeze(dim=-1)
        self.instance_id = F.pad(
            instance_id,
            (0, self.num_anchor - self.num_temp_instances),
            value=-1,
        )   # pad the remaining 300 slots with -1, back up to 900 instances

Post-Processing

After Sparse4DHead finishes, post-processing begins:

    results = self.head.post_process(model_outs)

Pick the highest-confidence class out of the 10, then keep the 300 highest-scoring instances:

        cls_scores = cls_scores[output_idx].sigmoid()   # torch.Size([1, 900, 10])

        if squeeze_cls: # taken
            cls_scores, cls_ids = cls_scores.max(dim=-1)    # torch.Size([1, 900])  pick the most confident of the 10 classes
            cls_scores = cls_scores.unsqueeze(dim=-1)   # torch.Size([1, 900, 1])

        box_preds = box_preds[output_idx]   # torch.Size([1, 900, 11])
        bs, num_pred, num_cls = cls_scores.shape
        cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
            self.num_output, dim=1, sorted=self.sorted
        )   # torch.Size([1, 300])  only 300 boxes are output

Re-weight the scores with the centerness:

        if qulity is not None:  # taken ("qulity" is spelled this way in the source)
            centerness = qulity[output_idx][..., CNS]   # torch.Size([1, 900])
            centerness = torch.gather(centerness, 1, indices // num_cls)    # torch.Size([1, 300])
            cls_scores_origin = cls_scores.clone()
            cls_scores *= centerness.sigmoid()  # torch.Size([1, 300])  multiply by the centerness
            cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True)    # re-sort
            if not squeeze_cls: # not taken
                cls_ids = torch.gather(cls_ids, 1, idx)
            if self.score_threshold is not None:    # not taken
                mask = torch.gather(mask, 1, idx)
            indices = torch.gather(indices, 1, idx)
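
The indices // num_cls arithmetic maps the flattened top-k indices back to anchor indices (and indices % num_cls would give the class when classes are not squeezed; with squeeze_cls, num_cls == 1, so the division is a no-op). A toy illustration of the general case:

    import torch

    num_cls = 10
    scores = torch.rand(1, 900, num_cls)
    flat_scores, indices = scores.flatten(start_dim=1).topk(300, dim=1)

    anchor_idx = indices // num_cls   # which of the 900 anchors
    class_idx = indices % num_cls     # which of the 10 classes
    assert torch.equal(flat_scores, scores[0, anchor_idx[0], class_idx[0]][None])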

Pull out the boxes:

        for i in range(bs):
            category_ids = cls_ids[i]   # torch.Size([900])
            if squeeze_cls:
                category_ids = category_ids[indices[i]] # torch.Size([300])
            scores = cls_scores[i]  # torch.Size([300])
            box = box_preds[i, indices[i] // num_cls]   # torch.Size([300, 11])

Decode the boxes into 10-dim 3D boxes:

    def decode_box(self, box):
        yaw = torch.atan2(box[:, SIN_YAW], box[:, COS_YAW]) # torch.Size([300])
        box = torch.cat(
            [
                box[:, [X, Y, Z]],
                box[:, [W, L, H]].exp(),
                yaw[:, None],
                box[:, VX:],
            ],
            dim=-1,
        )   # torch.Size([300, 10])
        return box
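
A worked example of the decoding with hypothetical values; the stored layout is [x, y, z, log w, log l, log h, sin θ, cos θ, vx, vy, vz]:

    import torch

    X, Y, Z, W, L, H, SIN_YAW, COS_YAW, VX = 0, 1, 2, 3, 4, 5, 6, 7, 8

    box = torch.tensor([[10., 5., 0., 0.6931, 1.3863, 0.4055, 1., 0., 2., 0., 0.]])
    yaw = torch.atan2(box[:, SIN_YAW], box[:, COS_YAW])
    decoded = torch.cat(
        [box[:, [X, Y, Z]], box[:, [W, L, H]].exp(), yaw[:, None], box[:, VX:]],
        dim=-1,
    )
    print(decoded)  # [10, 5, 0, 2.0, 4.0, 1.5, pi/2, 2, 0, 0]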

The final output contains 5 things: 3D boxes, 3D scores, category IDs, classification scores, and instance IDs (300 of each).
