def attention(self,
              query,
              value,
              reference_points_cam,
              weights,
              proj_mask,
              spatial_shapes,
              level_start_index,
              ):
    """Run per-camera deformable attention only on the queries each camera sees.

    For every (sample, camera) pair the queries flagged by ``proj_mask`` are
    gathered into a dense buffer padded to the largest hit count, deformable
    attention is evaluated on that compact buffer, and the results are
    scattered back and summed over cameras.

    Args:
        query: Indexed as ``query[b, q]``; only its shape/dtype/device (for
            the output buffer) are used here.
            Presumably (bs, num_query, embed_dims) -- confirm with caller.
        value: Image features; the first three dims are read as
            (bs, num_cam, num_value) and the tensor is viewed as
            (bs * num_cam, num_value, num_heads, -1).
        reference_points_cam: Sampling locations, indexed as ``[b, cam, q]``.
            Assumed (bs, num_cam, num_query, num_heads, num_levels,
            num_points, 2) -- TODO confirm against caller.
        weights: Attention weights of shape
            (bs, num_query, num_head, num_cam, num_level, num_p).
        proj_mask: Indexed as ``[b, cam]`` and then flattened from dim 1; a
            query is kept for a camera when the sum of its flattened mask
            entries is non-zero.
        spatial_shapes: Forwarded unchanged to the deformable attention op.
        level_start_index: Forwarded unchanged to the CUDA deformable
            attention op (unused on the pure-PyTorch path).

    Returns:
        Tensor shaped like ``query`` holding, for each query, the sum of the
        attention outputs over all cameras that selected it. Queries selected
        by no camera stay zero.
    """
    bs, num_cam, num_value = value.shape[:3]
    num_all_points = weights.size(-1)
    slots = torch.zeros_like(query)

    # (bs, num_query, num_head, num_cam, num_level, num_p)
    # --> (bs, num_cam, num_query, num_head, num_level, num_p)
    weights = weights.permute(0, 3, 1, 2, 4, 5).contiguous()

    # Memory-saving trick (similar to bevformer_occ): for each camera keep
    # only the indices of the queries whose mask has any non-zero entry.
    indexes = [
        [
            proj_mask[i, j].flatten(1).sum(-1).nonzero().squeeze(-1)
            for j in range(num_cam)
        ]
        for i in range(bs)
    ]
    # Every (sample, camera) slot is padded to the largest hit count.
    max_len = max(
        (idx.numel() for per_sample in indexes for idx in per_sample),
        default=0,
    )

    # NOTE(review): the buffers are sized with self.num_cams while the loops
    # and the views below use num_cam taken from `value`; the fixed-shape
    # .view()/.reshape() calls only succeed when the two are equal.
    reference_points_cam_rebatch = reference_points_cam.new_zeros(
        [bs, self.num_cams, max_len, self.num_heads, self.num_levels,
         num_all_points, 2])
    weights_rebatch = weights.new_zeros(
        [bs, self.num_cams, max_len, self.num_heads, self.num_levels,
         num_all_points])

    # Pack the selected entries at the front of each slot; the padded tail
    # keeps zero locations and zero weights.
    # The query features themselves are not re-batched: nothing downstream
    # in this method reads them, so copying them would only cost memory.
    for i in range(bs):
        for j in range(num_cam):
            index_query_per_img = indexes[i][j]
            curr_numel = index_query_per_img.numel()
            reference_points_cam_rebatch[i, j, :curr_numel] = \
                reference_points_cam[i, j, index_query_per_img]
            weights_rebatch[i, j, :curr_numel] = \
                weights[i, j, index_query_per_img]

    # Fold the camera axis into the batch axis for the attention op.
    value = value.view(bs * num_cam, num_value, self.num_heads, -1)
    sampling_locations = reference_points_cam_rebatch.view(
        bs * num_cam, max_len, self.num_heads, self.num_levels,
        num_all_points, 2)
    attention_weights = weights_rebatch.reshape(
        bs * num_cam, max_len, self.num_heads, self.num_levels,
        num_all_points)

    if torch.cuda.is_available() and value.is_cuda:
        output = MultiScaleDeformableAttnFunction_fp32.apply(
            value, spatial_shapes, level_start_index, sampling_locations,
            attention_weights, self.im2col_step)
    else:
        output = multi_scale_deformable_attn_pytorch(
            value, spatial_shapes, sampling_locations, attention_weights)

    output = output.view(bs, num_cam, max_len, -1)

    # Scatter back: drop the padded tail and accumulate over cameras.
    for i in range(bs):
        for j in range(num_cam):
            index_query_per_img = indexes[i][j]
            slots[i, index_query_per_img] += \
                output[i, j, :len(index_query_per_img)]
    return slots
最新发布