PointNet++代码注释

本文详细介绍了PointNet++的实现细节,包括ShapeNetDataLoader模块用于将点云数据转换为可处理的数组,PointNetSetAbstraction和PointNetFeaturePropagation网络结构,以及关键的点采样和聚类算法。这些内容展示了如何处理定制的点云数据集,并通过PointNet++进行特征学习和上采样操作。

  复现了一下PointNet++,拿自己做的数据集试试感觉效果挺好,于是开始读读代码,打打基础,顺便找点灵感和思路。代码中有些注释是基于我自己制作的数据集进行解释的,我的数据集仿照的是ShapeNet数据集的格式,在我的数据集中,只有一种物体:book(书),book有两个部分:background(背景)和seam(书缝),分别对应0和1。


ShapeNetDataLoader.py

  ShapeNetDataLoader.py作用就是把n个点云数据转换成一个数组,数组有n项,每项包含点的信息,点的大类别(book),点的小类别(background,seam)。

# *_*coding:utf-8 *_*
import os
import json
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')

def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc

class PartNormalDataset(Dataset):
    def __init__(self,root = './data/book_seam_dataset', npoints=50000, split='train', class_choice=None, normal_channel=False):
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.cat = {
   
   }
        self.normal_channel = normal_channel


        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
        self.cat = {
   
   k: v for k, v in self.cat.items()} # {'book': '12345678'}
        self.classes_original = dict(zip(self.cat, range(len(self.cat)))) # {'book': 0}

        if not class_choice is  None:
            self.cat = {
   
   k:v for k,v in self.cat.items() if k in class_choice}
        # print(self.cat)

        self.meta = {
   
   }
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
            train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) # {'1', '2', ...}
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
            val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
            test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        for item in self.cat:  # item:'book'
            # print('category', item)
            self.meta[item] = []
            dir_point = os.path.join(self.root, self.cat[item])
            fns = sorted(os.listdir(dir_point))
            # print(fns[0][0:-4])
            if split == 'trainval': # 取训练集+验证集
                fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] # fn[0:-4]就是‘1.txt’里面的‘1’, fns:['1.txt', '10.txt', ...]
            elif split == 'train':
                fns = [fn for fn in fns if fn[0:-4] in train_ids]
            elif split == 'val':
                fns = [fn for fn in fns if fn[0:-4] in val_ids]
            elif split == 'test':
                fns = [fn for fn in fns if fn[0:-4] in test_ids]
            else:
                print('Unknown split: %s. Exiting..' % (split))
                exit(-1)

            # print(os.path.basename(fns))
            f
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

a_struggler

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值