The previous article, “Sparse4Dv3 代码学习(Ⅱ)单帧推理” (Sparse4Dv3 code study II: single-frame inference, CSDN blog), covered inference for a single frame, i.e. the first frame of a sequence. This article walks through how inference is handled once historical frames are brought in.
① InstanceBank: its main job here is to project the cached anchors into the current frame.
# ========= get instance info ============
if (
    self.sampler.dn_metas is not None
    and self.sampler.dn_metas["dn_anchor"].shape[0] != batch_size
):  # not entered on the first frame
    self.sampler.dn_metas = None
(
    instance_feature,
    anchor,
    temp_instance_feature,
    temp_anchor,
    time_interval,
) = self.instance_bank.get(
    batch_size, metas, dn_metas=self.sampler.dn_metas
)
The first notable difference is the instance bank: at this point it holds the cached instance features and anchors from the previous frame.
def get(self, batch_size, metas=None, dn_metas=None):
    instance_feature = torch.tile(
        self.instance_feature[None], (batch_size, 1, 1)  # self.instance_feature: torch.Size([900, 256])
    )  # torch.Size([1, 900, 256])
    anchor = torch.tile(self.anchor[None], (batch_size, 1, 1))  # torch.Size([1, 900, 11])
    if (
        self.cached_anchor is not None
        and batch_size == self.cached_anchor.shape[0]
    ):  # not entered on the first frame
        history_time = self.metas["timestamp"]
        time_interval = metas["timestamp"] - history_time  # tensor([0.4999], device='cuda:0', dtype=torch.float64)
        time_interval = time_interval.to(dtype=instance_feature.dtype)
        self.mask = torch.abs(time_interval) <= self.max_time_interval  # 0.5 < 2, so the history is usable
The first step is to transform the cached anchors into the current frame:
if self.anchor_handler is not None:
    T_temp2cur = self.cached_anchor.new_tensor(
        np.stack(
            [
                x["T_global_inv"]
                @ self.metas["img_metas"][i]["T_global"]
                for i, x in enumerate(metas["img_metas"])
            ]
        )
    )  # torch.Size([1, 4, 4])
    self.cached_anchor = self.anchor_handler.anchor_projection(
        self.cached_anchor,
        [T_temp2cur],
        time_intervals=[-time_interval],
    )[0]
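To make the pose composition concrete: x["T_global_inv"] belongs to the current frame and self.metas["img_metas"][i]["T_global"] to the cached one, so T_temp2cur takes a point from the previous ego frame, through global coordinates, into the current ego frame. Below is a minimal sketch with made-up poses (illustrative values only, not from the dataset):

import numpy as np
import torch

def make_pose(yaw, tx, ty):
    # Hypothetical 4x4 ego-to-global pose (rotation about z plus a translation).
    c, s = np.cos(yaw), np.sin(yaw)
    T = np.eye(4)
    T[:2, :2] = [[c, -s], [s, c]]
    T[:3, 3] = [tx, ty, 0.0]
    return T

T_global_prev = make_pose(0.0, 10.0, 0.0)   # ego pose at the cached frame
T_global_cur = make_pose(0.1, 15.0, 1.0)    # ego pose now
T_global_inv_cur = np.linalg.inv(T_global_cur)

# Same composition as in InstanceBank.get: previous-ego -> global -> current-ego
T_temp2cur = torch.tensor(T_global_inv_cur @ T_global_prev)  # (4, 4)

# A point that was 5 m ahead of the ego vehicle in the previous frame,
# expressed in current ego coordinates:
p_prev = torch.tensor([5.0, 0.0, 0.0, 1.0], dtype=torch.float64)
p_cur = T_temp2cur @ p_prev
print(p_cur[:3])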
The projection details inside, including the projection of the center point (center) and the conversion of the velocity (vel):
@staticmethod  # inter-frame projection of anchors
def anchor_projection(
    anchor,
    T_src2dst_list,
    src_timestamp=None,
    dst_timestamps=None,
    time_intervals=None,
):
    dst_anchors = []
    for i in range(len(T_src2dst_list)):
        vel = anchor[..., VX:]
        vel_dim = vel.shape[-1]
        T_src2dst = torch.unsqueeze(
            T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
        )  # torch.Size([1, 1, 4, 4])

        center = anchor[..., [X, Y, Z]]
        if time_intervals is not None:
            time_interval = time_intervals[i]  # tensor([-0.4999], device='cuda:0')
        elif src_timestamp is not None and dst_timestamps is not None:
            time_interval = (src_timestamp - dst_timestamps[i]).to(
                dtype=vel.dtype
            )
        else:
            time_interval = None
        if time_interval is not None:
            translation = vel.transpose(0, -1) * time_interval
            translation = translation.transpose(0, -1)
            center = center - translation
        center = (
            torch.matmul(
                T_src2dst[..., :3, :3], center[..., None]
            ).squeeze(dim=-1)
            + T_src2dst[..., :3, 3]
        )

        size = anchor[..., [W, L, H]]
        yaw = torch.matmul(
            T_src2dst[..., :2, :2],
            anchor[..., [COS_YAW, SIN_YAW], None],
        ).squeeze(-1)
        vel = torch.matmul(
            T_src2dst[..., :vel_dim, :vel_dim], vel[..., None]
        ).squeeze(-1)
        dst_anchor = torch.cat([center, size, yaw, vel], dim=-1)

        # TODO: Fix bug
        # index = [X, Y, Z, W, L, H, COS_YAW, SIN_YAW] + [VX, VY, VZ][:vel_dim]
        # index = torch.tensor(index, device=dst_anchor.device)
        # index = torch.argsort(index)
        # dst_anchor = dst_anchor.index_select(dim=-1, index=index)
        dst_anchors.append(dst_anchor)
    return dst_anchors
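As a sanity check on the sign convention: the call above passes time_intervals=[-time_interval], and with dt = t_cur - t_prev > 0 the expression center - vel * (-dt) equals center + vel * dt, i.e. the cached box is first moved forward under a constant-velocity assumption, and only then rigidly transformed into the current frame. A minimal numeric sketch with made-up values (only the velocity compensation, rotation left out):

import torch

dt = 0.5                              # current timestamp minus cached timestamp
time_interval = torch.tensor([-dt])   # the call above passes -time_interval

center = torch.tensor([[[5.0, 0.0, 0.0]]])   # (bs=1, num_anchor=1, 3), position in the old frame
vel = torch.tensor([[[2.0, 0.0, 0.0]]])      # 2 m/s along x, estimated in the old frame

translation = vel.transpose(0, -1) * time_interval   # same per-batch broadcast trick as above
translation = translation.transpose(0, -1)
center = center - translation
print(center)   # tensor([[[6., 0., 0.]]]): moved forward by vel * dt before the rigid transform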
② Encoding the historical anchors
if temp_anchor is not None:
    temp_anchor_embed = self.anchor_encoder(temp_anchor)
③ The historical anchor embedding temp_anchor_embed is used later in refine. Since 600 instances come from the cache, only 300 new anchors need to be picked out here:
N = self.num_anchor - self.num_temp_instances  # 900 - 600 = 300
confidence = confidence.max(dim=-1).values
_, (selected_feature, selected_anchor) = topk(
    confidence, N, instance_feature, anchor
)
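topk here is a small helper from the repository; the sketch below is an assumed equivalent (not the original implementation) that keeps the N highest-confidence instances and gathers the matching features and anchors:

import torch

def topk_sketch(confidence, k, *inputs):
    # confidence: (bs, num_anchor); keep the k highest-scoring instances
    conf, indices = confidence.topk(k, dim=1)
    outputs = []
    for x in inputs:
        # gather along the anchor dimension, keeping the trailing feature dims
        idx = indices[..., None].expand(-1, -1, x.shape[-1])
        outputs.append(x.gather(1, idx))
    return conf, outputs

confidence = torch.rand(1, 900)
instance_feature = torch.rand(1, 900, 256)
anchor = torch.rand(1, 900, 11)
conf, (feat300, anchor300) = topk_sketch(confidence, 300, instance_feature, anchor)
print(feat300.shape, anchor300.shape)   # torch.Size([1, 300, 256]) torch.Size([1, 300, 11])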
The 300 selected anchors are then merged with the 600 cached ones (anchors, instance features, and instance IDs alike):
selected_feature = torch.cat(
    [self.cached_feature, selected_feature], dim=1
)
selected_anchor = torch.cat(
    [self.cached_anchor, selected_anchor], dim=1
)
instance_feature = torch.where(
    self.mask[:, None, None], selected_feature, instance_feature
)  # since self.mask is True here, selected_feature is actually taken everywhere
anchor = torch.where(self.mask[:, None, None], selected_anchor, anchor)
if self.instance_id is not None:
    self.instance_id = torch.where(
        self.mask[:, None],
        self.instance_id,
        self.instance_id.new_tensor(-1),
    )
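Shape-wise, the 600 cached instances are concatenated in front of the 300 newly selected ones, and torch.where with the broadcast mask keeps the merged set when the time gap is acceptable, otherwise falling back to the freshly tiled anchors. A minimal sketch with random tensors standing in for the real ones:

import torch

cached_feature = torch.rand(1, 600, 256)      # 600 temporal instances from the cache
selected_feature = torch.rand(1, 300, 256)    # 300 instances picked from the current frame
instance_feature = torch.rand(1, 900, 256)    # freshly tiled fallback from the instance bank

mask = torch.tensor([True])                   # abs(time_interval) <= max_time_interval
merged = torch.cat([cached_feature, selected_feature], dim=1)            # (1, 900, 256)
instance_feature = torch.where(mask[:, None, None], merged, instance_feature)
print(instance_feature.shape)                 # torch.Size([1, 900, 256])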
④ The temporal information is used inside the gnn (the temp_gnn op):
elif op == "temp_gnn":
    instance_feature = self.graph_model(
        i,
        instance_feature,  # q
        temp_instance_feature,  # k
        temp_instance_feature,  # v
        query_pos=anchor_embed,
        key_pos=temp_anchor_embed,
        attn_mask=attn_mask
        if temp_instance_feature is None
        else None,
    )
def graph_model(
    self,
    index,
    query,
    key=None,
    value=None,
    query_pos=None,
    key_pos=None,
    **kwargs,
):
    if self.decouple_attn:  # entered (decouple_attn is True)
        query = torch.cat([query, query_pos], dim=-1)  # torch.Size([1, 900, 512])
        if key is not None:  # skipped by temp_gnn on the first frame, entered on later frames
            key = torch.cat([key, key_pos], dim=-1)
        query_pos, key_pos = None, None
    if value is not None:  # gnn always enters; temp_gnn only once history exists
        value = self.fc_before(value)  # torch.Size([1, 900, 512]) / torch.Size([1, 600, 512])
    return self.fc_after(
        self.layers[index](
            query,
            key,
            value,
            query_pos=query_pos,
            key_pos=key_pos,
            **kwargs,
        )
    )
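With decouple_attn=True, the positional embedding is attached to query/key by concatenation (256 + 256 = 512) rather than addition, while the value stays content-only and is lifted to 512 by fc_before; fc_after projects the attention output back to 256. Below is a minimal sketch of that wiring for the temp_gnn case; the attention layer and its hyper-parameters here are assumptions for illustration, not the repository's config:

import torch
import torch.nn as nn

embed_dims = 256
attn = nn.MultiheadAttention(embed_dims * 2, num_heads=8, batch_first=True)
fc_before = nn.Linear(embed_dims, embed_dims * 2, bias=False)
fc_after = nn.Linear(embed_dims * 2, embed_dims, bias=False)

instance_feature = torch.rand(1, 900, embed_dims)        # current queries
anchor_embed = torch.rand(1, 900, embed_dims)
temp_instance_feature = torch.rand(1, 600, embed_dims)   # cached instances
temp_anchor_embed = torch.rand(1, 600, embed_dims)

# temp_gnn: query = current instances, key/value = cached instances
q = torch.cat([instance_feature, anchor_embed], dim=-1)             # (1, 900, 512)
k = torch.cat([temp_instance_feature, temp_anchor_embed], dim=-1)   # (1, 600, 512)
v = fc_before(temp_instance_feature)                                 # (1, 600, 512)
out, _ = attn(q, k, v)
out = fc_after(out)                                                   # back to (1, 900, 256)
print(out.shape)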
⑤ Finally, 600 instances are cached again for the next frame:
def cache(
    self,
    instance_feature,
    anchor,
    confidence,
    metas=None,
    feature_maps=None,
):
    if self.num_temp_instances <= 0:
        return
    instance_feature = instance_feature.detach()
    anchor = anchor.detach()
    confidence = confidence.detach()

    self.metas = metas
    confidence = confidence.max(dim=-1).values.sigmoid()
    if self.confidence is not None:
        confidence[:, : self.num_temp_instances] = torch.maximum(
            self.confidence * self.confidence_decay,  # self.confidence_decay = 0.6
            confidence[:, : self.num_temp_instances],  # the first 600 are the temporal instances
        )
    self.temp_confidence = confidence

    (
        self.confidence,
        (self.cached_feature, self.cached_anchor),
    ) = topk(confidence, self.num_temp_instances, instance_feature, anchor)
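The confidence decay acts as a running memory of how confident the detector has ever been about each temporal instance: the cached score is damped by confidence_decay (0.6) and compared with the freshly predicted one, so a briefly occluded object is not immediately evicted from the top 600. A minimal numeric sketch:

import torch

confidence_decay = 0.6
prev_conf = torch.tensor([0.9, 0.2])   # cached confidence of two temporal instances
new_conf = torch.tensor([0.3, 0.5])    # confidence re-estimated in the current frame

fused = torch.maximum(prev_conf * confidence_decay, new_conf)
print(fused)   # tensor([0.5400, 0.5000]): the first keeps its decayed history, the second takes the new score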