数据预处理:这里仅介绍 ShapeNetDataset(ShapeNet 点云数据集的加载类)。
class ShapeNetDataset(data.Dataset):
    """Loads the ShapeNet part-segmentation dataset.

    Each sample is a point cloud of ``npoints`` points (resampled with
    replacement), centered at the origin, scaled into the unit sphere, and
    optionally augmented with a random rotation about the Y axis plus
    Gaussian jitter.

    ``__getitem__`` returns:
        (points, class_id)    when ``classification`` is True
        (points, seg_labels)  otherwise
    """

    def __init__(self,
                 root,
                 npoints=2500,
                 classification=False,
                 class_choice=None,
                 split='train',
                 data_augmentation=True):
        """
        Args:
            root: dataset root directory (must contain
                ``synsetoffset2category.txt`` and ``train_test_split/``).
            npoints: number of points sampled per cloud.
            classification: if True, items are (points, class id);
                otherwise (points, per-point segmentation labels).
            class_choice: optional collection of category names to keep;
                None keeps all categories.
            split: which shuffled file list to load ('train', 'test', ...).
            data_augmentation: apply random rotation + jitter in __getitem__.
        """
        self.npoints = npoints  # number of points per sampled cloud
        self.root = root
        # Maps category name -> synset folder name, e.g. 'Chair' -> '03001627'.
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.cat = {}
        self.data_augmentation = data_augmentation  # whether to augment samples
        self.classification = classification        # task mode (cls vs. seg)
        self.seg_classes = {}

        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
        # Optionally restrict to a user-supplied subset of categories.
        # (fix: `x is not None` instead of `not x is None`)
        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
        self.id2cat = {v: k for k, v in self.cat.items()}

        # Read the official train/test split: a JSON list of relative paths
        # shaped like 'shape_data/<synset>/<uuid>'.
        self.meta = {}
        splitfile = os.path.join(self.root, 'train_test_split',
                                 'shuffled_{}_file_list.json'.format(split))
        # fix: close the file handle instead of json.load(open(...))
        with open(splitfile, 'r') as f:
            filelist = json.load(f)
        for item in self.cat:
            self.meta[item] = []
        # Turn each split entry into (points path, per-point label path),
        # grouped by category name. (fix: loop var no longer shadows `file`)
        for entry in filelist:
            _, category, uuid = entry.split('/')
            if category in self.cat.values():
                self.meta[self.id2cat[category]].append(
                    (os.path.join(self.root, category, 'points', uuid + '.pts'),
                     os.path.join(self.root, category, 'points_label', uuid + '.seg')))

        # Flatten into (category name, points path, labels path) records.
        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn[0], fn[1]))

        # Stable category -> integer class id mapping (sorted for determinism).
        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
        print(self.classes)

        # Number of segmentation parts per category, shipped with the repo
        # next to this module (../misc/num_seg_classes.txt).
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               '../misc/num_seg_classes.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.seg_classes[ls[0]] = int(ls[1])
        # NOTE(review): assumes all selected categories share this count when
        # class_choice picks a single category — verify for multi-class use.
        self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
        print(self.seg_classes, self.num_seg_classes)

    def __getitem__(self, index):
        """Return one sample; see the class docstring for the tuple layout."""
        fn = self.datapath[index]
        cls = self.classes[fn[0]]
        point_set = np.loadtxt(fn[1]).astype(np.float32)
        seg = np.loadtxt(fn[2]).astype(np.int64)

        # Resample (with replacement) to a fixed number of points.
        choice = np.random.choice(len(seg), self.npoints, replace=True)
        point_set = point_set[choice, :]

        # Center, then scale so the farthest point lies on the unit sphere.
        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist

        if self.data_augmentation:
            # Random rotation about the Y (up) axis: mixes the X and Z columns.
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)],
                                        [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)
            # Gaussian jitter: mean 0, standard deviation 0.02.
            point_set += np.random.normal(0, 0.02, size=point_set.shape)

        seg = seg[choice]
        point_set = torch.from_numpy(point_set)
        seg = torch.from_numpy(seg)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))

        if self.classification:
            return point_set, cls
        else:
            return point_set, seg
def __len__(self):

最低0.47元/天 解锁文章
6950

被折叠的评论
为什么被折叠?



