Sparse4Dv3 代码学习(Ⅲ)时序多帧推理

上一篇文章Sparse4Dv3 代码学习(Ⅱ)单帧推理-优快云博客介绍了单帧,也就是序列的第一帧的推理过程,这篇文章主要介绍引入历史帧推理时的处理过程。

①InstanceBank,主要是把缓存的anchor投影到当前帧

        # ========= get instance info ============
        if (
            self.sampler.dn_metas is not None
            and self.sampler.dn_metas["dn_anchor"].shape[0] != batch_size
        ):  # 第一帧不进入
            self.sampler.dn_metas = None
        (
            instance_feature,
            anchor,
            temp_instance_feature,
            temp_anchor,
            time_interval,
        ) = self.instance_bank.get(
            batch_size, metas, dn_metas=self.sampler.dn_metas
        )

第一个很不一样的是实例库,这个时候实例库缓存了历史帧的实例特征和anchor

    def get(self, batch_size, metas=None, dn_metas=None):
        instance_feature = torch.tile(
            self.instance_feature[None], (batch_size, 1, 1) # self.instance_feature: torch.Size([900, 256])
        )   # torch.Size([1, 900, 256])
        anchor = torch.tile(self.anchor[None], (batch_size, 1, 1))  # torch.Size([1, 900, 11])

        if (
            self.cached_anchor is not None
            and batch_size == self.cached_anchor.shape[0]
        ):  # 第一帧时不进入
            history_time = self.metas["timestamp"]
            time_interval = metas["timestamp"] - history_time   # tensor([0.4999], device='cuda:0', dtype=torch.float64)
            time_interval = time_interval.to(dtype=instance_feature.dtype)
            self.mask = torch.abs(time_interval) <= self.max_time_interval  # 0.5<2

第一步是把anchor转换到当前帧:

            if self.anchor_handler is not None:
                T_temp2cur = self.cached_anchor.new_tensor(
                    np.stack(
                        [
                            x["T_global_inv"]
                            @ self.metas["img_metas"][i]["T_global"]
                            for i, x in enumerate(metas["img_metas"])
                        ]
                    )
                )   # torch.Size([1, 4, 4])
                self.cached_anchor = self.anchor_handler.anchor_projection(
                    self.cached_anchor,
                    [T_temp2cur],
                    time_intervals=[-time_interval],
                )[0]

里面的投影的细节,包括中心点(center)的投影和速度(vel)的转换:

    @staticmethod   # 这应该是帧间anchor的投影转换
    def anchor_projection(
        anchor,
        T_src2dst_list,
        src_timestamp=None,
        dst_timestamps=None,
        time_intervals=None,
    ):
        dst_anchors = []
        for i in range(len(T_src2dst_list)):
            vel = anchor[..., VX:]
            vel_dim = vel.shape[-1]
            T_src2dst = torch.unsqueeze(
                T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
            )   # torch.Size([1, 1, 4, 4])

            center = anchor[..., [X, Y, Z]]
            if time_intervals is not None:  
                time_interval = time_intervals[i]   # tensor([-0.4999], device='cuda:0')
            elif src_timestamp is not None and dst_timestamps is not None:
                time_interval = (src_timestamp - dst_timestamps[i]).to(
                    dtype=vel.dtype
                )
            else:
                time_interval = None
            if time_interval is not None:
                translation = vel.transpose(0, -1) * time_interval
                translation = translation.transpose(0, -1)
                center = center - translation
            center = (
                torch.matmul(
                    T_src2dst[..., :3, :3], center[..., None]
                ).squeeze(dim=-1)
                + T_src2dst[..., :3, 3]
            )
            size = anchor[..., [W, L, H]]
            yaw = torch.matmul(
                T_src2dst[..., :2, :2],
                anchor[..., [COS_YAW, SIN_YAW], None],
            ).squeeze(-1)
            vel = torch.matmul(
                T_src2dst[..., :vel_dim, :vel_dim], vel[..., None]
            ).squeeze(-1)
            dst_anchor = torch.cat([center, size, yaw, vel], dim=-1)
            # TODO: Fix bug
            # index = [X, Y, Z, W, L, H, COS_YAW, SIN_YAW] + [VX, VY, VZ][:vel_dim]
            # index = torch.tensor(index, device=dst_anchor.device)
            # index = torch.argsort(index)
            # dst_anchor = dst_anchor.index_select(dim=-1, index=index)
            dst_anchors.append(dst_anchor)
        return dst_anchors

②历史anchor的编码

        if temp_anchor is not None:
            temp_anchor_embed = self.anchor_encoder(temp_anchor)

 ③在refine里面会用到历史anchor的编码temp_anchor_embed:

现在只需要挑出300个anchor:

        N = self.num_anchor - self.num_temp_instances   # 900-600等于300
        confidence = confidence.max(dim=-1).values
        _, (selected_feature, selected_anchor) = topk(
            confidence, N, instance_feature, anchor
        )

根据选择出来的300个anchor合并到缓存的600个anchor一起(包括anchor、实例特征、实例ID):

        selected_feature = torch.cat(
            [self.cached_feature, selected_feature], dim=1
        )
        selected_anchor = torch.cat(
            [self.cached_anchor, selected_anchor], dim=1
        )
        instance_feature = torch.where(
            self.mask[:, None, None], selected_feature, instance_feature
        )   # 因为self.mask=True,所以实际上都选择的是selected_feature
        anchor = torch.where(self.mask[:, None, None], selected_anchor, anchor)
        if self.instance_id is not None:
            self.instance_id = torch.where(
                self.mask[:, None],
                self.instance_id,
                self.instance_id.new_tensor(-1),
            )

④在gnn里面会用到时间信息

            elif op == "temp_gnn":
                instance_feature = self.graph_model(
                    i,
                    instance_feature,   # q
                    temp_instance_feature,  # k
                    temp_instance_feature,  # v
                    query_pos=anchor_embed,
                    key_pos=temp_anchor_embed,
                    attn_mask=attn_mask
                    if temp_instance_feature is None
                    else None,
                )
    def graph_model(
        self,
        index,
        query,
        key=None,
        value=None,
        query_pos=None,
        key_pos=None,
        **kwargs,
    ):
        if self.decouple_attn:  # 进入
            query = torch.cat([query, query_pos], dim=-1)   # torch.Size([1, 900, 512])
            if key is not None: # 第一帧的gnn_temp不进入  后续会进入
                key = torch.cat([key, key_pos], dim=-1)
            query_pos, key_pos = None, None
        if value is not None:   # temp_gnn不进入  gnn进入
            value = self.fc_before(value)   # torch.Size([1, 900, 512]) torch.Size([1, 600, 512])
        return self.fc_after(
            self.layers[index](
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                **kwargs,
            )
        )

④继续缓存600个:

    def cache(
        self,
        instance_feature,
        anchor,
        confidence,
        metas=None,
        feature_maps=None,
    ):
        if self.num_temp_instances <= 0:
            return
        instance_feature = instance_feature.detach()
        anchor = anchor.detach()
        confidence = confidence.detach()

        self.metas = metas
        confidence = confidence.max(dim=-1).values.sigmoid()
        if self.confidence is not None:
            confidence[:, : self.num_temp_instances] = torch.maximum(
                self.confidence * self.confidence_decay,    # self.confidence_decay=0.6
                confidence[:, : self.num_temp_instances],   # 选前600
            )
        self.temp_confidence = confidence

        (
            self.confidence,
            (self.cached_feature, self.cached_anchor),
        ) = topk(confidence, self.num_temp_instances, instance_feature, anchor)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值