geotransformer.utils.data模块解读
0. 写在前面
-
使用自己的数据进行点云数据配准的训练时, 需要运行GeoTransformer/experiments/geotransformer.kitti.stage5.gse.k3.max.oacl.stage2.sinkhorn/下面的trainval.py,它是调用同目录下的dataset.py来组织数据的。
-
而在这个文件的开头, 就import 了geotransformer.utils.data模块中的三个函数, 因此理解这个模块是非常重要的。
-
首先将dataset.py中的流程在这里简要梳理一下, 方便更好地理解为什么要去解读geotransformer.utils.data模块:
这个流程对应的代码(GeoTransformer/experiments/geotransformer.kitti.stage5.gse.k3.max.oacl.stage2.sinkhorn/dataset.py):
train_dataset = OdometryKittiPairDataset(
cfg.data.dataset_root,
'train',
point_limit=cfg.train.point_limit,
use_augmentation=cfg.train.use_augmentation,
augmentation_noise=cfg.train.augmentation_noise,
augmentation_min_scale=cfg.train.augmentation_min_scale,
augmentation_max_scale=cfg.train.augmentation_max_scale,
augmentation_shift=cfg.train.augmentation_shift,
augmentation_rotation=cfg.train.augmentation_rotation,
)
neighbor_limits = calibrate_neighbors_stack_mode(
train_dataset,
registration_collate_fn_stack_mode,
cfg.backbone.num_stages,
cfg.backbone.init_voxel_size,
cfg.backbone.init_radius,
)
train_loader = build_dataloader_stack_mode(
train_dataset,
registration_collate_fn_stack_mode,
cfg.backbone.num_stages,
cfg.backbone.init_voxel_size,
cfg.backbone.init_radius,
neighbor_limits,
batch_size=cfg.train.batch_size,
num_workers=cfg.train.num_workers,
shuffle=True,
distributed=distributed,
)
1. precompute_data_stack_mode
1.1. 功能
对点云数据进行多阶段的网格下采样(grid subsampling),
在每个阶段中,都开展一次邻域搜索(radius search),以准备训练所需的数据结构。
这里还有一个细节是grid_subsample, 它返回采样后的点,以及点的数量(length)。
这里记录的length很重要。以ref_points为例,在model.py中,由data_dict中每个阶段的length信息, 取最后一次下采样的点记作ref_points_c(coarse,粗糙的), 第一次下采样后的点记作ref_points_f(fine,精细的), 然后在point_to_node_partition函数中,便是以ref_points_c为node的。
1.2. 输入
- points:点云坐标。
- lengths:每个点云中的点数。
- num_stages:下采样的阶段数。
- voxel_size:初始体素大小。
- radius:邻域搜索半径。
- neighbor_limits:各阶段允许的最大邻域点数限制。
1.3. 输出
一个字典,包含各个阶段的点云、长度、邻域信息、下采样和上采样索引。
1.4. 代码
# Stack mode utilities
def precompute_data_stack_mode(points, lengths, num_stages, voxel_size, radius, neighbor_limits):
assert num_stages == len(neighbor_limits)
points_list = []
lengths_list = [