写在前面:
对于 BEVFormer 算法框架的整体理解,大家可以找到大量的资料参考,但是对于算法代码的解读缺乏详实的资料。因此,本系列的目的是结合代码实现细节、在 tensor 维度的变换中帮助读者对算法能有更直观的认识。
本系列我们将对 BEVFormer 公版代码(开源算法)进行逐行解析,以结合代码理解 Bevformer 原理,掌握算法细节,帮助读者们利用该算法框架开发感知算法。在本系列的最后笔者还将面向地平线的用户,指出地平线参考算法在开源算法基础上做出的修改及修改背后的考虑,在算法部署过程中为用户提供参考。
公版代码目录封装较好,且以注册器的方式调用模型,各个模块的调用关系可以从 configs/bevformer 中的 config 文件中清晰体现,我们以 bevformer_tiny.py 为例3解析代码,Encoder 部分已经发出,见《BEVFormer 开源算法逐行解析(一):Encoder 部分》,本文主要关注 BEVFormer 的 Decoder 和 Det 部分。
对代码的解析和理解主要体现在代码注释中。
1 PerceptionTransformer:
功能:
- 将 encoder 层输出的 bev_embed 传入 decoder 中
- 将在 BEVFormer 中定义的 query_embedding 按通道拆分成通道数相同的 query_pos 和 query,并传入 decoder 中;
- 利用 query_pos 通过线性层 reference_points 生成 reference_points,并传入 decoder;该 reference_points 在 decoder 中的CustimMSDeformableAttention 作为融合 bev_embed 的基准采样点,作用类似于 two-stage 目标检测中的 Region Proposal ;
- 返回 inter_states, inter_references 给 cls_branches 和 reg_branches 分支得到目标的种类和 bboxes。
解析:
#详见《BEVFormer开源算法逐行解析(一):Encoder部分》,用于获得bev_embed
#在decoder中利用CustimMSDeformableAttention将bev_embed与query融合
bev_embed = self.get_bev_features(
mlvl_feats,
bev_queries,
bev_h,
bev_w,
grid_length=grid_length,
bev_pos=bev_pos,
prev_bev=prev_bev,
**kwargs) # bev_embed shape: bs, bev_h*bev_w, embed_dims
bs = mlvl_feats[0].size(0)
#object_query_embed:torch.Size([900, 512])
#query_pos:torch.Size([900, 256])
#query:torch.Size([900, 256])
query_pos, query = torch.split(
object_query_embed, self.embed_dims, dim=1)
query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1)
query = query.unsqueeze(0).expand(bs, -1, -1)
#reference_points:torch.Size([1, 900, 3])
reference_points = self.reference_points(query_pos)
reference_points = reference_points.sigmoid()
init_reference_out = reference_points
#query:torch.Size([900, 1, 256])
query = query.permute(1, 0, 2)
#query_pos:torch.Size([900, 1, 256])
query_pos = query_pos.permute(1, 0, 2)
#bev_embed:torch.Size([50*50, 1, 256])
bev_embed = bev_embed.permute(1, 0, 2)
#进入decoder模块!
inter_states, inter_references = self.decoder(
query=query,
key=None,
value=bev_embed,
query_pos=query_pos,
reference_points=reference_points,
reg_branches=reg_branches,
cls_branches=cls_branches,
spatial_shapes=torch.tensor([[bev_h, bev_w]], device=query.device),
level_start_index=torch.tensor([0], device=query.device),
**kwargs)
#返回inter_states, inter_references
#后续用于提供给cls_branches和reg_branches分支得到目标的种类和bboxes
inter_references_out = inter_references
return bev_embed, inter_states, init_reference_out, inter_references_out
2 DetectionTransformerDecoder
功能:
- 循环进入6个相同的 DetrTransformerDecoderLayer,一个 DetrTransformerDecoderLayer 包含 (‘self_attn’, ‘norm’, ‘cross_attn’, ‘norm’, ‘ffn’, ‘norm’),每层输出 output 和 reference_points;
- 在6层 DetrTransformerDecoderLayer 遍历完成后,将6层输出的 output 和 reference_points 输出。
解析:
#output:torch.Size([900, 1, 256])
output = query
intermediate = []
intermediate_reference_points = []
#循环进入6个相同的DetrTransformerDecoderLayer模块
for lid, layer in enumerate(self.layers):
#reference_points_input:torch.Size([1, 900, 1, 2])
#该reference_points在decoder中的CustimMSDeformableAttention作为融合bev_embed的基准采样点
reference_points_input = reference_points[..., :2].unsqueeze(
2) # BS NUM_QUERY NUM_LEVEL 2
#进入某一层DetrTransformerDecoderLayer
output = layer(
output,
*args,
reference_points=reference_points_input,
key_padding_mask=key_padding_mask,
**kwargs)
#output:torch.Size([1, 900, 256])
output = output.permute(1, 0, 2)
if reg_branches is not None:
#tmp:torch.Size([1, 900, 10])
tmp = reg_branches[lid](output)
assert reference_points.shape[-1] == 3
#new_reference_pointtorch.Size([1, 900, 3])
new_referenc