复现了一下PointNet++,拿自己做的数据集试试感觉效果挺好,于是开始读读代码,打打基础,顺便找点灵感和思路。代码中有些注释是基于我自己制作的数据集进行解释的,我的数据集仿照的是ShapeNet数据集的格式,在我的数据集中,只有一种物体:book(书),book有两个部分:background(背景)和seam(书缝),分别对应0和1。
ShapeNetDataLoader.py
ShapeNetDataLoader.py作用就是把n个点云数据转换成一个数组,数组有n项,每项包含点的信息,点的大类别(book),点的小类别(background,seam)。
# *_*coding:utf-8 *_*
import os
import json
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')
def pc_normalize(pc):
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
pc = pc / m
return pc
class PartNormalDataset(Dataset):
def __init__(self,root = './data/book_seam_dataset', npoints=50000, split='train', class_choice=None, normal_channel=False):
self.npoints = npoints
self.root = root
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
self.cat = {
}
self.normal_channel = normal_channel
with open(self.catfile, 'r') as f:
for line in f:
ls = line.strip().split()
self.cat[ls[0]] = ls[1]
self.cat = {
k: v for k, v in self.cat.items()} # {'book': '12345678'}
self.classes_original = dict(zip(self.cat, range(len(self.cat)))) # {'book': 0}
if not class_choice is None:
self.cat = {
k:v for k,v in self.cat.items() if k in class_choice}
# print(self.cat)
self.meta = {
}
with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) # {'1', '2', ...}
with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
for item in self.cat: # item:'book'
# print('category', item)
self.meta[item] = []
dir_point = os.path.join(self.root, self.cat[item])
fns = sorted(os.listdir(dir_point))
# print(fns[0][0:-4])
if split == 'trainval': # 取训练集+验证集
fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] # fn[0:-4]就是‘1.txt’里面的‘1’, fns:['1.txt', '10.txt', ...]
elif split == 'train':
fns = [fn for fn in fns if fn[0:-4] in train_ids]
elif split == 'val':
fns = [fn for fn in fns if fn[0:-4] in val_ids]
elif split == 'test':
fns = [fn for fn in fns if fn[0:-4] in test_ids]
else:
print('Unknown split: %s. Exiting..' % (split))
exit(-1)
# print(os.path.basename(fns))
f

本文详细介绍了PointNet++的实现细节,包括ShapeNetDataLoader模块用于将点云数据转换为可处理的数组,PointNetSetAbstraction和PointNetFeaturePropagation网络结构,以及关键的点采样和聚类算法。这些内容展示了如何处理定制的点云数据集,并通过PointNet++进行特征学习和上采样操作。
最低0.47元/天 解锁文章
8884

被折叠的 条评论
为什么被折叠?



