本代码将 SE(Squeeze-and-Excitation)模块融入 PointPillars 的 2D 骨干网络(BaseBEVBackbone)中,使每一个卷积块的输出都先经过一个 SE 模块进行通道重标定,再进行上采样。
1. 完成 OpenPCDet 的安装与环境配置(本文环境:Ubuntu 22.04、Python 3.7、pytorch11.3)
参考:《安装、测试和训练 OpenPCDet:一篇详尽的指南》(优快云 博客,https://blog.youkuaiyun.com/laukal/article/details/139395806)
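2. 修改代码
- 打开 OpenPCDet/pcdet/models/backbones_2d/ 目录下的 base_bev_backbone.py 文件。
- 找到代码第 8 行处的 BaseBEVBackbone 类。
- 将整个类替换为以下代码。新代码会用到 np、torch 和 nn,请确保文件顶部保留 import numpy as np、import torch、import torch.nn as nn;粘贴时注意保留 Python 缩进。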
class BaseBEVBackbone(nn.Module):
    """PointPillars 2D BEV backbone with a Squeeze-and-Excitation (SE) gate per level.

    Each level runs a (possibly strided) 3x3 conv followed by ``LAYER_NUMS[idx]``
    further 3x3 convs. The level output is then recalibrated channel-wise by an SE
    module (global average pool -> bottleneck MLP -> sigmoid gate). The gated map
    is passed both to the level's deblock (up/down-sampling) and to the next level.
    The per-level deblock outputs are concatenated along the channel axis.

    Keys read from ``model_cfg`` (accessed via ``.get`` and as attributes):
        LAYER_NUMS, LAYER_STRIDES, NUM_FILTERS: per-level conv settings; must have
            equal length.
        UPSAMPLE_STRIDES, NUM_UPSAMPLE_FILTERS: per-level deblock settings; must
            have equal length. A stride > 1 uses ConvTranspose2d; a stride < 1 uses
            Conv2d with stride round(1 / stride). If UPSAMPLE_STRIDES is longer than
            the number of levels, its last entry drives an extra ConvTranspose2d
            applied to the concatenated map.
        USE_CONV_FOR_NO_STRIDE: when true, a stride of exactly 1 uses Conv2d
            instead of ConvTranspose2d.
        SE_REDUCTION: optional channel reduction ratio of the SE bottleneck
            (default 4, the value that used to be hard-coded).
    """

    def __init__(self, model_cfg, input_channels):
        """
        Args:
            model_cfg: backbone config (see the class docstring for the keys read).
            input_channels: channel count of ``data_dict['spatial_features']``.
        """
        super().__init__()
        self.model_cfg = model_cfg
        if self.model_cfg.get('LAYER_NUMS', None) is not None:
            assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
            layer_nums = self.model_cfg.LAYER_NUMS
            layer_strides = self.model_cfg.LAYER_STRIDES
            num_filters = self.model_cfg.NUM_FILTERS
        else:
            layer_nums = layer_strides = num_filters = []

        if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:
            assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
            num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
            upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
        else:
            upsample_strides = num_upsample_filters = []

        se_reduction = self.model_cfg.get('SE_REDUCTION', 4)

        num_levels = len(layer_nums)
        c_in_list = [input_channels, *num_filters[:-1]]
        self.blocks = nn.ModuleList()
        self.deblocks = nn.ModuleList()
        self.SE = nn.ModuleList()
        # Parameter-free "squeeze" step, shared by every level's SE gate.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        for idx in range(num_levels):
            cur_layers = [
                nn.ZeroPad2d(1),
                nn.Conv2d(
                    c_in_list[idx], num_filters[idx], kernel_size=3,
                    stride=layer_strides[idx], padding=0, bias=False
                ),
                nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
                nn.ReLU()
            ]
            for _ in range(layer_nums[idx]):
                cur_layers.extend([
                    nn.Conv2d(num_filters[idx], num_filters[idx], kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(num_filters[idx], eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ])
            self.blocks.append(nn.Sequential(*cur_layers))
            self.SE.append(self._make_se(num_filters[idx], se_reduction))

            if len(upsample_strides) > 0:
                stride = upsample_strides[idx]
                if stride > 1 or (stride == 1 and not self.model_cfg.get('USE_CONV_FOR_NO_STRIDE', False)):
                    self.deblocks.append(nn.Sequential(
                        nn.ConvTranspose2d(
                            num_filters[idx], num_upsample_filters[idx],
                            upsample_strides[idx],
                            stride=upsample_strides[idx], bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))
                else:
                    # Fractional stride means downsampling by 1/stride. np.int was
                    # removed in NumPy 1.24, so convert with the builtin int().
                    stride = int(np.round(1 / stride))
                    self.deblocks.append(nn.Sequential(
                        nn.Conv2d(
                            num_filters[idx], num_upsample_filters[idx],
                            stride,
                            stride=stride, bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))

        c_in = sum(num_upsample_filters)
        if len(upsample_strides) > num_levels:
            self.deblocks.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False),
                nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ))

        self.num_bev_features = c_in

    @staticmethod
    def _make_se(channels, reduction):
        """Build the SE excitation MLP ``C -> C // reduction -> C`` ending in a sigmoid gate."""
        # Keep at least one hidden unit so small channel counts do not yield a zero-width Linear.
        hidden = max(1, channels // reduction)
        return nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def _apply_se(self, idx, x):
        """Scale each channel of the (B, C, H, W) map ``x`` by level ``idx``'s SE gate."""
        b, c = x.shape[:2]
        gate = self.SE[idx](self.avg_pool(x).view(b, c)).view(b, c, 1, 1)
        return x * gate

    def forward(self, data_dict):
        """
        Args:
            data_dict: dict holding 'spatial_features', the BEV feature map fed to
                the first block.
        Returns:
            The same dict, with 'spatial_features_2d' set to the concatenated
            (and optionally further upsampled) per-level features.
        """
        spatial_features = data_dict['spatial_features']
        ups = []
        x = spatial_features
        for i in range(len(self.blocks)):
            x = self.blocks[i](x)
            # Recalibrate channels *before* upsampling, so the deblock and the next
            # level both consume SE-weighted features. Previously the gate was
            # applied after the deblock, so it never reached the upsampled output.
            x = self._apply_se(i, x)
            if len(self.deblocks) > 0:
                ups.append(self.deblocks[i](x))
            else:
                ups.append(x)

        if len(ups) > 1:
            x = torch.cat(ups, dim=1)
        elif len(ups) == 1:
            x = ups[0]

        if len(self.deblocks) > len(self.blocks):
            x = self.deblocks[-1](x)

        data_dict['spatial_features_2d'] = x
        return data_dict
3. 开始训练(训练、测试与演示)
# 以下命令均在 OpenPCDet/tools 目录下执行。
# 由于骨干网络新增了 SE 层参数,官方预训练的 PointPillars 权重中不包含这些参数,
# 因此需先完成训练,得到 --ckpt 所指向的权重文件后,再运行演示与测试。
cd tools
# 演示文件
python demo.py --cfg_file cfgs/kitti_models/pointpillar.yaml --ckpt ../output/kitti_models/pointpillar/default/ckpt/latest_model.pth --data_path ../data/kitti/testing/velodyne/000034.bin
# 测试文件
python test.py --cfg_file cfgs/kitti_models/pointpillar.yaml --batch_size 4 --ckpt ../output/kitti_models/pointpillar/default/ckpt/latest_model.pth
# 训练文件
python train.py --cfg_file ./cfgs/kitti_models/pointpillar.yaml
4. 此时模型中已加载 SE 模块,可在打印出的模型结构中查看 BaseBEVBackbone 下的 SE 层进行确认。