provider.py文件主要为PointNet提供数据加载以及点云预处理等功能
其import如下:
import os
import sys
import numpy as np
import h5py
紧接着import的是对数据目录的一些处理:
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
其中os.path.abspath(__file__)
获取当前文件的绝对路径,例如”E:\test\provider.py“
,os.path.dirname()
则将该文件的绝对路径中的文件名取出,BASE_DIR = "E:\test"
,最后将其加入系统路径中,然后进行点云的下载,代码如下:
""" Download dataset for point cloud classification"""
DATA_DIR = os.path.join(BASE_DIR, 'data')
if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR)
if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
zipfile = os.path.basename(www)
os.system('wget %s; unzip %s' % (www, zipfile))
os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
os.system('rm %s' % (zipfile))
第一行通过将BASE_DIR
和'data'
合生成存放数据的路径DATA_DIR
,第一个if
判断该路径是否存在,如果不存在则通过mkdir
生产该路径。第二个if
判断DATA_DIR
路径下'modelnet40_ply_hdf5_2048'
是否存在,若不存在则下载文件并解压
此外,该文件中定义了若干函数如下:
def shuffle_data(data, labels):
""" Shuffle data and labels.
Input:
data: B,N,... numpy array
label: B,... numpy array
Return:
shuffled data, label and shuffle indices
"""
idx = np.arange(len(labels))
np.random.shuffle(idx)
return