Preface:
Plenty of material exists on the overall BEVFormer framework, but detailed resources that walk through the algorithm's code are scarce. This series therefore combines implementation details with tensor-shape transformations to give readers a more concrete understanding of the algorithm.
In this series we parse the public BEVFormer code (the open-source implementation) line by line, using the code to understand how BEVFormer works and to master its details, so that readers can develop perception algorithms on this framework. At the end of the series, for Horizon Robotics users, we will also point out the modifications the Horizon reference algorithm makes on top of the open-source code and the reasoning behind them, as a reference during deployment.
The public codebase is well modularized and instantiates models through a registry, so the call relationships between modules can be read directly from the config files under configs/bevformer. We take bevformer_tiny.py as the example for this walkthrough.
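For readers unfamiliar with the registry mechanism, the following is a minimal sketch of the pattern, heavily simplified from mmcv's Registry/build_from_cfg; only the 'type' key and the BEVFormer class name mirror the config below, everything else is illustrative.

# A minimal, simplified stand-in for mmcv's registry mechanism
# (illustration only; the real implementation lives in mmcv.utils.Registry).
MODELS = {}

def register_module(cls):
    """Record a class under its own name so configs can refer to it by 'type'."""
    MODELS[cls.__name__] = cls
    return cls

def build_from_cfg(cfg):
    """Pop 'type' from the config dict and instantiate the registered class."""
    cfg = dict(cfg)
    cls = MODELS[cfg.pop('type')]
    return cls(**cfg)

@register_module
class BEVFormer:
    def __init__(self, use_grid_mask=False, video_test_mode=False, **kwargs):
        self.use_grid_mask = use_grid_mask
        self.video_test_mode = video_test_mode

model = build_from_cfg(dict(type='BEVFormer', use_grid_mask=True))

This is why every sub-module in the config below carries a type field: the framework recursively builds each nested dict into the registered class of that name.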
Most of the parsing and explanation of the code lives in the code comments.
model = dict(
    type='BEVFormer',
    use_grid_mask=True,
    video_test_mode=True,
    pretrained=dict(img='torchvision://resnet50'),  # pretrained weights
    img_backbone=dict(
        type='ResNet',  # backbone network
        ...),
    img_neck=dict(
        type='FPN',  # neck network
        ...),
    pts_bbox_head=dict(
        type='BEVFormerHead',  # entry point of the detection head
        ...,
        transformer=dict(
            type='PerceptionTransformer',  # entry point of the transformer
            ...,
            encoder=dict(
                type='BEVFormerEncoder',  # entry point of the encoder
                ...,
                transformerlayers=dict(
                    type='BEVFormerLayer',  # a single encoder layer
                    attn_cfgs=[
                        dict(
                            type='TemporalSelfAttention',  # temporal self-attention
                            ...),
                        dict(
                            type='SpatialCrossAttention',  # spatial cross-attention
                            deformable_attention=dict(
                                type='MSDeformableAttention3D',
                                ...),
                            ...),
                    ],
                    ...)),
            decoder=dict(
                type='DetectionTransformerDecoder',  # entry point of the decoder
                ...,
                transformerlayers=dict(
                    type='DetrTransformerDecoderLayer',  # a single decoder layer
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',  # multi-head attention
                            ...),
                        dict(
                            type='CustomMSDeformableAttention',
                            ...),
                    ],
                    ...)))))
As the config above shows, during the forward pass the images from the six cameras first go through ResNet and FPN to produce image features, which then pass through the BEVFormerEncoder and DetectionTransformerDecoder inside BEVFormerHead to complete the feature-fusion pipeline. Each BEVFormerEncoder layer cascades a TemporalSelfAttention followed by a SpatialCrossAttention, and bevformer_tiny stacks 3 such cascaded layers.
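To make this flow concrete, below is a shape-only sketch of one sample. The 15x25 feature map follows from the 480x800 input at stride 32; the 50x50 BEV grid, 256 embedding channels, and 900 object queries are values assumed from the public tiny config rather than taken from this excerpt.

import torch

# Shape-only sketch of the bevformer_tiny forward path (bev_h = bev_w = 50,
# embed_dims = 256 and 900 queries are assumptions from the public config).
bs, queue, num_cams = 1, 3, 6
img = torch.zeros(bs, queue, num_cams, 3, 480, 800)   # raw input clip

# ResNet stage-4 output, stride 32: 480/32 x 800/32 = 15 x 25
backbone_out = torch.zeros(bs * num_cams, 2048, 15, 25)
# FPN projects the channels down to embed_dims at the same resolution
neck_out = torch.zeros(bs * num_cams, 256, 15, 25)

# BEVFormerEncoder (3 x TSA + SCA layers) fuses the 6 camera features
# and the previous BEV into a single BEV feature
bev_embed = torch.zeros(50 * 50, bs, 256)   # (bev_h * bev_w, bs, C)

# DetectionTransformerDecoder: object queries attend to bev_embed
object_query = torch.zeros(900, bs, 256)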
1 BEVFormer:
Functionality:
- applies grid_mask image augmentation (see the sketch after this list);
- extracts image features with the ResNet (backbone) and FPN (neck) networks;
- hands the features over to BEVFormerHead.
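As a point of reference for the grid_mask augmentation above, here is a minimal sketch of the core idea; the function name and defaults are hypothetical, and the repo's GridMask module additionally randomizes the grid offset/rotation and applies the mask with a probability schedule, all omitted here.

import torch

def simple_grid_mask(img: torch.Tensor, d: int = 96, ratio: float = 0.5) -> torch.Tensor:
    """Zero out a regular grid of square patches on an (N, C, H, W) batch.

    Stripped-down stand-in for the repo's GridMask: d is the grid period,
    ratio controls the masked fraction of each grid cell.
    """
    _, _, h, w = img.shape
    mask = torch.ones(h, w, device=img.device)
    k = int(d * ratio)                      # side length of each masked square
    for y in range(0, h, d):
        for x in range(0, w, d):
            mask[y:y + k, x:x + k] = 0      # punch out one square per cell
    return img * mask                       # broadcasts over N and C

# masked = simple_grid_mask(torch.randn(12, 3, 480, 800))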
Analysis:
#img: (bs, queue=3, num_cams=6, C=3, H=480, W=800)
#split the clip along the queue dimension into prev_img and the current img
len_queue = img.size(1)
prev_img = img[:, :-1, ...] #(1,2,6,3,480,800)
img = img[:, -1, ...] #(1,6,3,480,800)
#use every frame in the queue except the current one to generate prev_bev
prev_img_metas = copy.deepcopy(img_metas)
prev_bev = self.obtain_history_bev(prev_img, prev_img_metas)
#------------------------obtain_history_bev start--------------------------------------
#use the 2 frames before the current one (t-2, t-1) to generate prev_bev for the later TSA
def obtain_history_bev(self, imgs_queue, img_metas_list):
    """Obtain history BEV features iteratively. To save GPU memory, gradients are not calculated.
    """
    self.eval()
    with torch.no_grad():
        prev_bev = None
        #imgs_queue: torch.Size(1,2,6,3,480,800)
        bs, len_queue, num_cams, C, H, W = imgs_queue.shape
        #imgs_queue: torch.Size(2,6,3,480,800)
        imgs_queue = imgs_queue.reshape(bs*len_queue, num_cams, C, H, W)
        img_feats_list = self.extract_feat(img=imgs_queue, len_queue=len_queue)
#--------------------------------extract_feat start--------------------------------------
#extract image features and reshape them
@auto_fp16(apply_to=('img'))
def extract_feat(self, img, img_metas=None, len_queue=None):
    """Extract features from images and points."""
    img_feats = self.extract_img_feat(img, img_metas, len_queue=len_queue)
    return img_feats
#--------------------------------extract_feat end----------------------------------------
#-----------------------------extract_img_feat start---------------------------------------
def extract_img_feat(self, img, img_metas, len_queue=None):
    """Extract features of images."""
    # B = batch_size
    B = img.size(0)
    if img is not None:
        if img.dim() == 5 and img.size(0) == 1:
            img.squeeze_()
        elif img.dim() == 5 and img.size(0) > 1:
            B, N, C, H, W = img.size()
            img = img.reshape(B * N, C, H, W)
        #an augmentation trick: occlude parts of the image with a grid mask so the
        #network has to learn more features of the target, which reduces overfitting
        if self.use_grid_mask:
            #when called from obtain_history_bev():
            #img: torch.Size(12,3,480,800), queue=3-1=2 (frames t-2 and t-1 generate prev_bev)
            #after obtain_history_bev, when called from extract_feat() on the current frame:
            #img: torch.Size(6,3,480,800), queue=3-2=1 (frame t plus prev_bev generate the bev_query)
            img = self.grid_mask(img)
        #when called from obtain_history_bev():
        #img_feats: tuple(torch.Size(12,2048,15,25)), 12=2*6: queue=3-1=2, cam_num=6
        #after obtain_history_bev, when called from extract_feat():
        #img_feats: tuple(torch.Size(6,2048,15,25)), 6=1*6: queue=3-2=1, cam_num=6
        img_feats = self.img_backbone(img)
        if isinstance(img_feats, dict):
            img_feats = list(img_feats.values())
    else:
        return None
    if self.with_img_neck:
        #likewise, img_feats: tuple(torch.Size(12,256,15,25)) / tuple(torch.Size(6,256,15,25))
        img_feats = self.img_neck(img_feats)
    img_feats_reshaped = []
    for img_feat in img_feats:
        BN, C, H, W = img_feat.size()
        #history frames (len_queue=3-1=2)
        if len_queue is not None:
            #img_feat is reshaped to:
            #torch.Size(B/len_queue=1, len_queue=2, num_cams=6, C=256, H=15, W=25)
            img_feats_reshaped.append(img_feat.view(int(B/len_queue), len_queue, int(BN / B), C, H, W))
        #current frame (len_queue is None)
        else:
            img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
    return img_feats_reshaped
#-----------------------------extract_img_feat end----------------------------------------
        for i in range(len_queue):
            img_metas = [each[i] for each in img_metas_list]
            if not img_metas[0]['prev_bev_exists']:
                prev_bev = None
            # img_feats = self.extract_feat(img=img, img_metas=img_metas)
            #slice img_feats along the queue dimension (from 6-D down to 5-D) into a list
            img_feats = [each_scale[:, i] for each_scale in img_feats_list]
            #generate prev_bev from img_feats and img_metas;
            #pts_bbox_head is the next stage, BEVFormerHead (encoder + decoder), called with only_bev=True
            prev_bev = self.pts_bbox_head(
                img_feats, img_metas, prev_bev, only_bev=True)
        self.train()
        return prev_bev
#------------------------obtain_history_bev end----------------------------------------
img_metas = [each[len_queue-1] for each in img_metas]
if not img_metas[0]['prev_bev_exists']:
    prev_bev = None
#extract the image features of the current frame
img_feats = self.extract_feat(img=img, img_metas=img_metas)
#obtain_history_bev generated prev_bev from the t-2 and t-1 images and their img_metas;
#now the current-frame features img_feats, that prev_bev, the current frame's img_metas,
#and the 3D box and class labels are fed into forward_pts_train to compute the loss.
#forward_pts_train then enters BEVFormerHead.
losses = dict()
losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
                                    gt_labels_3d, img_metas,
                                    gt_bboxes_ignore, prev_bev)
#------------------------------forward_pts_train start-----------------------------------------
def forward_pts_train(self, pts_feats, gt_bboxes_3d, gt_labels_3d, img_metas,
                      gt_bboxes_ignore=None, prev_bev=None):
    #pts_bbox_head is the next stage, BEVFormerHead (encoder + decoder)
    outs = self.pts_bbox_head(
        pts_feats, img_metas, prev_bev)
    loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
    losses = self.pts_bbox_head.loss(*loss_inputs, img_metas=img_metas)
    return losses