1.说明
想要对 PointPillars 的神经网络部分做较大改动,但个人感觉 OpenPCDet 架构下的源代码采用多层嵌套循环来构建网络,可读性较低;习惯了 YOLO 风格的代码后再看这种写法很不适应,也不便于在源代码的基础上进行改进。因此花时间重构了网络结构,以提高可读性和可扩展性,使网络代码可以像 YOLO 代码一样逐层添加新的模块、灵活修改网络结构。重构后的代码见下文附录。
2. 使用方法
将 OpenPCDet/pcdet/models/backbones_2d/base_bev_backbone.py 文件中的 BaseBEVBackbone 类整体替换为下列代码(整个类全部替换;附录中的 ConvModule、deConvModule、ConvModule2 三个辅助类需一并放入同一文件)。
添加图片注释,不超过 140 字(可选)
3.代码逻辑
定义 ConvModule 类进行 3×3 卷积;定义 deConvModule 类进行上采样(转置卷积);定义 ConvModule2 类构建多层卷积网络。
在网络class BaseBEVBackbone(nn.Module)中逐层定义网络结构
在前向传播中逐层传播
4.附录代码
class ConvModule(nn.Module):
    """Single 2-D convolution optionally followed by BatchNorm and ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size (default 3).
        stride: Convolution stride (default 1).
        padding: Zero padding applied by the convolution (default 1).
        use_bn: Append ``BatchNorm2d`` after the convolution. The convolution
            bias is disabled when this is true.
        use_relu: Append an in-place ``ReLU`` as the last layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_bn=True, use_relu=True):
        super().__init__()
        stages = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=not use_bn,
            )
        ]
        if use_bn:
            stages += [nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01)]
        if use_relu:
            stages += [nn.ReLU(inplace=True)]
        # Attribute name is kept as `conv` so parameter names stay the same.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv (+ BN + ReLU) stack to ``x`` and return the result."""
        return self.conv(x)
class deConvModule(nn.Module):
    """Upsampling layer: transposed 2-D convolution plus optional BN and ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the transposed conv.
        kernel_size: Transposed-convolution kernel size (default 3).
        stride: Transposed-convolution stride (default 2).
        padding: Padding argument of ``ConvTranspose2d`` (default 0).
        use_bn: Append ``BatchNorm2d`` after the transposed conv. The bias of
            the transposed conv is disabled when this is true.
        use_relu: Append an in-place ``ReLU`` as the last layer.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, use_bn=True, use_relu=True):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                bias=not use_bn,
            )
        ]
        if use_bn:
            stages += [nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01)]
        if use_relu:
            stages += [nn.ReLU(inplace=True)]
        # Attribute name is kept as `deconv` so parameter names stay the same.
        self.deconv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the transposed conv (+ BN + ReLU) stack to ``x``."""
        return self.deconv(x)
class ConvModule2(nn.Module):
    """Strided 3x3 convolution stage followed by extra 3x3 convolutions.

    Layer order: ``ZeroPad2d(1)`` -> strided ``Conv2d`` -> [BN] -> [ReLU],
    then ``num_1x1_layers`` repetitions of ``Conv2d(3x3, stride 1, padding 1)``
    -> [BN] -> [ReLU].

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by every conv in the stage.
        kernel_size: Kernel size of the first (strided) convolution.
        stride: Stride of the first convolution (default 2).
        padding: Extra padding applied by the first convolution *in addition*
            to the fixed one-pixel ``ZeroPad2d``. The default 0 reproduces the
            usual "pad 1, conv 3x3" behavior.
        use_bn: Append ``BatchNorm2d`` after every convolution. Convolution
            biases are disabled when this is true.
        use_relu: Append an in-place ``ReLU`` after every convolution.
        num_1x1_layers: Number of additional stride-1 convolutions appended
            after the first one. NOTE: despite the historical name, these are
            3x3 convolutions with padding 1, not 1x1 convolutions; the name
            is kept only for backward compatibility with existing callers.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, use_bn=True, use_relu=True, num_1x1_layers=0):
        super().__init__()

        def norm_and_activation():
            # Optional BN / ReLU tail shared by every convolution in the stage.
            tail = []
            if use_bn:
                tail.append(nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01))
            if use_relu:
                tail.append(nn.ReLU(inplace=True))
            return tail

        # First convolution: explicit one-pixel zero pad, then the strided conv.
        # `padding` used to be silently ignored here; it is now forwarded.
        layers = [
            nn.ZeroPad2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding, bias=not use_bn),
        ]
        layers.extend(norm_and_activation())

        # Additional shape-preserving 3x3 convolutions (out_channels -> out_channels).
        for _ in range(num_1x1_layers):
            layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=not use_bn))
            layers.extend(norm_and_activation())

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the whole convolution stage and return the result."""
        return self.conv(x)
class BaseBEVBackbone(nn.Module):
    """Three-stage 2-D BEV backbone written out layer by layer.

    Each stage ``i`` consists of a downsampling conv block (``c{i}``) and an
    upsampling branch (``d{i}``) applied to that block's output. The three
    upsampled maps are concatenated along the channel dimension.

    Args:
        model_cfg: Config object supporting ``.get(key, default)`` and
            attribute access. Keys read: ``LAYER_NUMS``, ``LAYER_STRIDES``,
            ``NUM_FILTERS``, ``UPSAMPLE_STRIDES``, ``NUM_UPSAMPLE_FILTERS``.
            Each list must contain exactly three entries.
        input_channels: Number of channels of ``spatial_features``.

    Attributes:
        num_bev_features: Channel count of the concatenated output
            (sum of ``NUM_UPSAMPLE_FILTERS``).

    Raises:
        AssertionError: If the configured lists have inconsistent lengths.
        ValueError: If the config does not describe exactly three stages.
    """

    def __init__(self, model_cfg, input_channels):
        super().__init__()
        self.model_cfg = model_cfg

        if self.model_cfg.get('LAYER_NUMS', None) is not None:
            assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
            layer_nums = self.model_cfg.LAYER_NUMS
            layer_strides = self.model_cfg.LAYER_STRIDES
            num_filters = self.model_cfg.NUM_FILTERS
        else:
            layer_nums, layer_strides, num_filters = [], [], []

        if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:
            assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
            num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
            upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
        else:
            upsample_strides, num_upsample_filters = [], []

        # The network below is hard-wired to three stages. Fail early with a
        # clear message instead of an IndexError (too few entries) or a
        # num_bev_features that disagrees with the real output (too many).
        if len(num_filters) != 3 or len(num_upsample_filters) != 3:
            raise ValueError(
                'BaseBEVBackbone expects exactly 3 stages, got '
                f'{len(num_filters)} conv stage(s) and '
                f'{len(num_upsample_filters)} upsample stage(s)'
            )

        # Input channels of stage i are the output channels of stage i - 1.
        c_in_list = [input_channels, *num_filters[:-1]]

        # Stage 0
        idx = 0
        self.c1 = ConvModule2(c_in_list[idx], num_filters[idx], stride=layer_strides[idx], padding=0, num_1x1_layers=layer_nums[idx])  # layer 0
        self.d1 = deConvModule(num_filters[idx], num_upsample_filters[idx], kernel_size=upsample_strides[idx], stride=upsample_strides[idx])  # layer 1

        # Stage 1
        idx = 1
        self.c2 = ConvModule2(c_in_list[idx], num_filters[idx], stride=layer_strides[idx], padding=0, num_1x1_layers=layer_nums[idx])  # layer 2
        self.d2 = deConvModule(num_filters[idx], num_upsample_filters[idx], kernel_size=upsample_strides[idx], stride=upsample_strides[idx])  # layer 3

        # Stage 2
        idx = 2
        self.c3 = ConvModule2(c_in_list[idx], num_filters[idx], stride=layer_strides[idx], padding=0, num_1x1_layers=layer_nums[idx])  # layer 4
        self.d3 = deConvModule(num_filters[idx], num_upsample_filters[idx], kernel_size=upsample_strides[idx], stride=upsample_strides[idx])  # layer 5

        self.num_bev_features = sum(num_upsample_filters)

    def forward(self, data_dict):
        """Compute the concatenated multi-scale BEV feature map.

        Args:
            data_dict: Dict containing the input tensor under
                ``'spatial_features'``.

        Returns:
            The same ``data_dict`` with the channel-wise concatenation of the
            three upsampled maps stored under ``'spatial_features_2d'``.
        """
        x = data_dict['spatial_features']

        # Stage 0
        x = self.c1(x)    # layer 0
        up1 = self.d1(x)  # layer 1 (the module must be *called* on x)

        # Stage 1
        x = self.c2(x)    # layer 2
        up2 = self.d2(x)  # layer 3

        # Stage 2
        x = self.c3(x)    # layer 4
        up3 = self.d3(x)  # layer 5

        # Layer 6: fuse the three branches along the channel dimension.
        data_dict['spatial_features_2d'] = torch.cat([up1, up2, up3], dim=1)
        return data_dict