[Code Reading] RandLA-Net

This post takes a close look at how RandLA-Net is applied to the SemanticKITTI dataset, covering data preprocessing, the network structure, the forward pass, and the loss function. It walks through the encoder's dilated_res_block and random_sample as well as the decoder's nearest_interpolation, showing how the network performs semantic segmentation of 3D point clouds.



RandLA-Net is a recently proposed method for semantic segmentation of large-scale scenes, and it works remarkably well. My walkthrough of the paper itself is in another blog post; the code released by the authors is a TensorFlow implementation.

Next, let's look at what this code actually does, taking SemanticKITTI as the example.

Since the authors' code targets TensorFlow, I ported it to PyTorch; the code is on my GitHub and implements training on the SemanticKITTI dataset.

Data preprocessing

# utils/data_prepare-semantickitti.py

# line 42-50
points = DP.load_pc_kitti(join(pc_path, scan_id))  # raw .bin scan -> N x 3 xyz
labels = DP.load_label_kitti(join(label_path, str(scan_id[:-4]) + '.label'), remap_lut)  # remapped labels
# grid (voxel) sub-sampling of points and labels
sub_points, sub_labels = DP.grid_sub_sampling(points, labels=labels, grid_size=grid_size)
# KDTree over the sub-sampled points, reused at training time for cropping
search_tree = KDTree(sub_points)
KDTree_save = join(KDTree_path_out, str(scan_id[:-4]) + '.pkl')
np.save(join(pc_path_out, scan_id)[:-4], sub_points)  # [:-4] strips '.bin'; np.save appends '.npy'
np.save(join(label_path_out, scan_id)[:-4], sub_labels)
with open(KDTree_save, 'wb') as f:
    pickle.dump(search_tree, f)

As you can see, the preprocessing grid-subsamples the points and labels, builds a KDTree on the sub-sampled points, and saves all three to disk (.npy for points and labels, .pkl for the tree).
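
DP.grid_sub_sampling is implemented in the repo as a compiled C++ op, but the idea is simple voxel pooling: partition space into cubes of edge length grid_size (0.06 m in the SemanticKITTI config, if I remember correctly) and keep one representative point per occupied voxel. Below is a minimal NumPy sketch of that idea; the function name is illustrative and the label rule is simplified (the real op handles labels more carefully).

import numpy as np

def grid_sub_sampling_sketch(points, labels, grid_size=0.06):
    # assign every point to a voxel of edge length grid_size
    voxel = np.floor(points[:, :3] / grid_size).astype(np.int64)
    # group points by voxel
    _, inverse, counts = np.unique(voxel, axis=0, return_inverse=True, return_counts=True)
    inverse = inverse.reshape(-1)  # guard against NumPy versions returning a column vector
    n_voxels = len(counts)
    # representative point = barycenter of each occupied voxel
    sub_points = np.zeros((n_voxels, points.shape[1]))
    np.add.at(sub_points, inverse, points)
    sub_points /= counts[:, None]
    # simplified label rule: label of the first point that falls into each voxel
    order = np.argsort(inverse)
    first = order[np.searchsorted(inverse[order], np.arange(n_voxels))]
    return sub_points, labels[first]

pts = np.random.rand(100000, 3) * 50.0      # toy scan inside a 50 m cube
lbl = np.random.randint(0, 19, len(pts))
sub_pts, sub_lbl = grid_sub_sampling_sketch(pts, lbl, grid_size=0.5)
print(sub_pts.shape, sub_lbl.shape)         # far fewer points than the input

The point of this step is to cap the density of the raw scans once, offline, so that training never has to touch the full-resolution clouds.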

Dataset

The dataset is implemented by the SemanticKITTI class in main_SemanticKITTI.py. Let's see what this part does.

class SemanticKITTI:
    def __init__(self, test_id):
        ...
        
    # Generate the input data flow
    def get_batch_gen(self, split):
        ...
        def spatially_regular_gen():
            # Generator loop
            # line 72-79
            for i in range(num_per_epoch):
                if split != 'test':
                    cloud_ind = i
                    pc_path = path_list[cloud_ind]
                    pc, tree, labels = self.get_data(pc_path)
                    # crop a small point cloud
                    pick_idx = np.random.choice(len(pc), 1)
                    selected_pc, selected_labels, selected_idx = self.crop_pc(pc, labels, tree, pick_idx)
                ...
        ...
        return gen_func, gen_types, gen_shapes

    def get_data(self, file_path):  # load the points, KDTree and labels from the files that file_path points to
        ...
        return points, search_tree, labels

    @staticmethod
    def crop_pc(points, labels, search_tree, pick_idx):
        # crop a fixed size point cloud for training
        center_point = points[pick_idx, :].reshape(1, -1)
        select_idx = search_tree.query(center_point, k=cfg.num_points)[1][0]
        select_idx = DP.shuffle_idx(select_idx)
        select_points = points[select_idx]
        select_labels = labels[select_idx]
        return select_points, select_labels, select_idx
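
crop_pc is what turns an arbitrarily large scan into a fixed-size training sample: pick one random point as the center, ask the pre-built KDTree for its cfg.num_points nearest neighbours, and shuffle their order so the ordering does not encode distance to the center. Here is a stand-alone toy version of the same idea; the cloud, the crop size and the shuffle call are illustrative, not the repo's configuration.

import numpy as np
from sklearn.neighbors import KDTree

num_points = 4096                                    # illustrative crop size; the repo uses cfg.num_points
pc = np.random.rand(100000, 3).astype(np.float32)    # toy sub-sampled scan
labels = np.random.randint(0, 19, len(pc))

tree = KDTree(pc)                                    # in the repo this tree is loaded from the saved .pkl
pick_idx = np.random.choice(len(pc), 1)              # random center point
center = pc[pick_idx, :].reshape(1, -1)
select_idx = tree.query(center, k=num_points)[1][0]  # indices of the k nearest neighbours
select_idx = np.random.permutation(select_idx)       # same effect as DP.shuffle_idx
print(pc[select_idx].shape, labels[select_idx].shape)   # (4096, 3) (4096,)

Because the crop is always exactly num_points points, every training sample has the same shape, which is what allows straightforward batching of otherwise variable-sized scans.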
