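Sparse convolution has been used successfully in 3D object detection networks such as SECOND, Part-A^2, and PV-RCNN, which demonstrates its effectiveness. Compared with dense 3D convolution, it has a large advantage in both running speed and GPU memory consumption. Sparse convolution was brought to LiDAR-based detection by the SECOND paper, which also describes how to implement it. Part-A^2 and PV-RCNN, however, use a different version, one that is very easy to use from PyTorch.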
This article reads through the OpenPCDet and spconv code to see how spconv is actually implemented. OpenPCDet is used mainly to see how spconv is called.
The most important sentence is in the spconv README: "This implementation use gather-gemm-scatter framework to do sparse convolution." By the end of this article, that sentence should make sense.
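To make the phrase concrete before diving into the real code, here is a minimal pure-Python sketch of gather-GEMM-scatter for a submanifold convolution. It is an illustration under simplifying assumptions, not spconv's implementation: it works in 2D instead of 3D, uses Python lists instead of GPU tensors, and the names `matmul` and `subm_conv2d` are made up for this example.

```python
from itertools import product


def matmul(a, b):
    """Plain-Python GEMM: (n x k) times (k x m) gives (n x m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def subm_conv2d(coords, feats, weights, ksize=3):
    """Submanifold sparse convolution: outputs exist only at the active input sites."""
    index_of = {c: i for i, c in enumerate(coords)}  # coordinate -> row in feats
    c_out = len(next(iter(weights.values()))[0])
    out = [[0.0] * c_out for _ in coords]
    r = ksize // 2
    for offset in product(range(-r, r + 1), repeat=2):
        # Rule book for this kernel offset: (input row, output row) pairs.
        pairs = []
        for out_idx, (y, x) in enumerate(coords):
            neighbour = (y + offset[0], x + offset[1])
            if neighbour in index_of:
                pairs.append((index_of[neighbour], out_idx))
        if not pairs:
            continue
        gathered = [feats[i] for i, _ in pairs]       # gather
        products = matmul(gathered, weights[offset])  # GEMM
        for (_, o), row in zip(pairs, products):      # scatter-add
            for c, v in enumerate(row):
                out[o][c] += v
    return out


# Three active sites with two input channels each.
coords = [(0, 0), (0, 1), (5, 5)]
feats = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# Every kernel offset sums the two input channels into one output channel.
weights = {off: [[1.0], [1.0]] for off in product(range(-1, 2), repeat=2)}

out = subm_conv2d(coords, feats, weights)
print(out)  # [[10.0], [10.0], [11.0]]
```

The two neighbouring sites each see both of their features, while the isolated site at (5, 5) sees only its own. Only the active pairs are ever multiplied, which is where the speed and memory savings over a dense convolution come from.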
OpenPCDet: VoxelBackBone8x
The goal here is to see how spconv is used. Look at the VoxelBackBone8x class in pcdet/models/backbones_3d/spconv_backbone.py in OpenPCDet.
from functools import partial

import spconv
import torch.nn as nn


class VoxelBackBone8x(nn.Module):
    def __init__(self, model_cfg, input_channels, grid_size, **kwargs):
        super().__init__()
        self.model_cfg = model_cfg
        norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        # grid_size is a NumPy array, so this adds 1 to the z extent of the reversed (z, y, x) grid
        self.sparse_shape = grid_size[::-1] + [1, 0, 0]  # self.sparse_shape = [41, 1600, 1408]
        self.conv_input = spconv.SparseSequential(
            spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'),  # the layer to focus on
            norm_fn(16),
            nn.ReLU(),
        )
        ...

    def forward(self, batch_dict):
        """
        Args:
            batch_dict:
                batch_size: int
                vfe_features: (num_voxels, C)
                voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx]
        Returns:
            batch_dict:
                encoded_spconv_tensor: sparse tensor
        """
        voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords']
        batch_size = batch_dict['batch_size']
        # First build the sparse tensor
        input_sp_tensor = spconv.SparseConvTensor(
            features=voxel_features,  # [32000, 3]: 32000 is the number of voxels in the two point-cloud frames, 3 is xyz
            indices=voxel_coords.int(),  # [32000, 4]: voxel coordinates; the first entry is the frame index, the other three are (z, y, x)
            spatial_shape=self.sparse_shape,  # self.sparse_shape = [41, 1600, 1408], determined by the KITTI config
            batch_size=batch_size  # batch_size = 2, i.e. two point-cloud frames
        )
        x = self.conv_input(input_sp_tensor)  # run the submanifold convolution
        ...
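The `sparse_shape` value in `__init__` is easy to misread: in OpenPCDet `grid_size` is a NumPy array, so `grid_size[::-1] + [1, 0, 0]` is an elementwise addition, not a list concatenation. The sketch below reproduces the arithmetic in plain Python. The point-cloud range and voxel size are assumptions taken from the commonly used OpenPCDet KITTI configuration (they are not stated in this article), and `compute_sparse_shape` is a hypothetical helper name.

```python
# Assumed values from the usual OpenPCDet KITTI config (not given in this article).
point_cloud_range = [0.0, -40.0, -3.0, 70.4, 40.0, 1.0]  # x_min, y_min, z_min, x_max, y_max, z_max
voxel_size = [0.05, 0.05, 0.1]                           # voxel edge lengths along x, y, z in metres


def compute_sparse_shape(pc_range, voxel):
    """Return the (z, y, x) spatial shape handed to the sparse tensor."""
    # Number of voxels along x, y, z.
    grid_size = [round((pc_range[i + 3] - pc_range[i]) / voxel[i]) for i in range(3)]
    # Reverse to (z, y, x) and add one extra slice along z, mirroring
    # the elementwise `grid_size[::-1] + [1, 0, 0]` on a NumPy array.
    z, y, x = grid_size[::-1]
    return [z + 1, y, x]


sparse_shape = compute_sparse_shape(point_cloud_range, voxel_size)
print(sparse_shape)  # [41, 1600, 1408]
```

With a plain Python list in place of the NumPy array, the same `+` would produce a six-element list, so the array type matters here.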
spconv
Sparse Tensor
Before running a sparse convolution, a sparse tensor has to be built. Here is how it is constructed, in spconv/__init__.py:
class SparseConvTensor(object):
    def __init__(self, features, indices, spatial_shape, batch_size, grid=None):
        """
        Args:
            grid: pre-allocated grid tensor. should be used when the volume of spatial shape
                is very large.
        """
        self.features = features  # features stored densely, one row per active voxel
        self.indices = indices  # coordinates of each feature in the voxel coordinate system
        if self.indices.dtype