PointNet数据预处理+网络训练

数据预处理

数据预处理,这里仅介绍一个shapenetdataset;

class ShapeNetDataset(data.Dataset):
    def __init__(self,
                 root,
                 npoints=2500,
                 classification=False,
                 class_choice=None,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints  # 单个数据集的点数
        self.root = root
        self.catfile = os.path.join(self.root,'synsetoffset2category.txt') #各个类别的数据对应的文件夹的路径
        self.cat = {
   
   }
        self.data_augmentation = data_augmentation # 是否进行数据增强
        self.classification = classification       # 数据的种类
        self.seg_classes = {
   
   }

        with open(self.catfile,'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]

        if not class_choice is None:
            self.cat = {
   
   k: v for k,v in self.cat.items() if k in class_choice}

        self.id2cat = {
   
   v:k for k,v in self.cat.items()}

        self.meta = {
   
   }

        # 读取已经分类好的数据的地址
        splitfile = os.path.join(self.root,'train_test_split','shuffled_{}_file_list.json'.format(split))

        filelist = json.load(open(splitfile,'r'))
        for item in self.cat:
            self.meta[item] = []

        # 数据存储地址的转换
        for file in filelist:
            _,category,uuid = file.split('/')
            if category in self.cat.values():
                self.meta[self.id2cat[category]].append((os.path.join(self.root,category,'points',uuid+'.pts'),
                                                         os.path.join(self.root, category, 'points_label', uuid+'.seg')))

        #按类别存储数据路径
        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item,fn[0],fn[1]))

        self.classes = dict(zip(sorted(self.cat),range(len(self.cat))))

        print(self.classes)
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)),'../misc/num_seg_classes.txt'),'r') as f:
            for line in f:
                ls = line.strip().split()
                self.seg_classes[ls[0]] = int(ls[1])
        self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
        print(self.seg_classes,self.num_seg_classes)

    #__getitem__ 方法通常用于定制类的实例对象的索引访问行为,使得类的实例可以像序列(如列表、元组)或映射(如字典)一样进行索引操作。
    # 在你的代码片段中,这个方法的定义可能是为了支持类实例的索引访问,比如 instance[index] 的操作。
    def __getitem__(self, index):
        fn = self.datapath[index]
        cls = self.classes[self.datapath[index][0]]
        point_set = np.loadtxt(fn[1]).astype(np.float32)
        seg = np.loadtxt((fn[2])).astype(np.int64)

        choice = np.random.choice(len(seg),self.npoints,replace=True)
        #resample
        point_set = point_set[choice,:]
        #去中心化
        point_set = point_set - np.expand_dims(np.mean(point_set,axis=0),0)
        #单位化
        dist = np.max(np.sqrt(np.sum(point_set ** 2,axis = 1)),0)
        point_set = point_set / dist

        # 采用随机旋转和随机高斯抖动对数据进行数据增强
        if self.data_augmentation:
            theta = np.random.uniform(0,np.pi*2)
            rotation_matrix = np.array([[np.cos(theta),-np.sin(theta)],[np.sin(theta),np.cos(theta)]])
            point_set[:,[0,2]] = point_set[:,[0,2]].dot(rotation_matrix) # 随机旋转
            point_set += np.random.normal(0,0.02,size=point_set.shape) # 生成的随机数服从均值为 0,标准差为 0.02 的正态分布

        seg = seg[choice]
        point_set = torch.from_numpy(point_set)
        seg = torch.from_numpy(seg)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))

        if self.classification:
            return point_set,cls
        else:
            return  point_set,seg

    def __len__(self):
        
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值