用torch.utils.data构建自己的数据读取器

创建数据读取类

自己创建的类需要继承 data.Dataset,同时必须要重载__getitem__以及 __len__函数

import torch.utils.data as data
class PNetDataSet(data.Dataset):   
  def __init__(self, imgs, transforms=None):
      '''
      目标:获取所有图片路径,并根据训练、验证、测试划分数据
      '''
      self.imgs = imgs
      self.transforms = transforms

  def __getitem__(self, index):
        '''
        返回一张图片的数据 (输入图片矩阵,该图片中gt的标签,gt的box坐标,landmark的坐标)
        '''
        line = self.imgs[index] #./DATA/12/positive/343776.jpg 1 0.12 0.05 -0.11 -0.01
        line=line.strip().split(" ")
        img_path="../prepare_data/"+line[0][1:]#../prepare_data/DATA/12/positive/343776.jpg
        print("img_path:",img_path)
        data = Image.open(img_path)
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        if len(data.getbands())!=3:
        #   print("{} is not RGB".format(img_path))
            data = data.convert('RGB')
        data = self.transforms(data)
        label = int(line[1])
        gt_bbox,gt_landmark=self.__getXy__(line)
        return data, label,gt_bbox,gt_landmark
  #获取每一行的box数据以及landmark数据
  def __getXy__(self,info):
        gt_bbox = dict()
        gt_landmark= dict()
        gt_bbox['xmin'] = 0
        gt_bbox['ymin'] = 0
        gt_bbox['xmax'] = 0
        gt_bbox['ymax'] = 0
        gt_landmark['xlefteye'] = 0
        gt_landmark['ylefteye'] = 0
        gt_landmark['xrighteye'] = 0
        gt_landmark['yrighteye'] = 0
        gt_landmark['xnose'] = 0
        gt_landmark['ynose'] = 0
        gt_landmark['xleftmouth'] = 0
        gt_landmark['yleftmouth'] = 0
        gt_landmark['xrightmouth'] = 0
        gt_landmark['yrightmouth'] = 0 
        if len(info) == 6:
            gt_bbox['xmin'] = float(info[2])
            gt_bbox['ymin'] = float(info[3])
            gt_bbox['xmax'] = float(info[4])
            gt_bbox['ymax'] = float(info[5])
        if len(info) == 12:
            gt_landmark['xlefteye'] = float(info[2])
            gt_landmark['ylefteye'] = float(info[3])
            gt_landmark['xrighteye'] = float(info[4])
            gt_landmark['yrighteye'] = float(info[5])
            gt_landmark['xnose'] = float(info[6])
            gt_landmark['ynose'] = float(info[7])
            gt_landmark['xleftmouth'] = float(info[8])
            gt_landmark['yleftmouth'] = float(info[9])
            gt_landmark['xrightmouth'] = float(info[10])
            gt_landmark['yrightmouth'] = float(info[11])  
        return gt_bbox,gt_landmark

  def __len__(self):
      '''
      返回数据集中所有图片的个数
      '''
      return len(self.imgs)

实现该类

label_file文件内容为
在这里插入图片描述

label_file ='../prepare_data/DATA/imglists/PNet/train_PNet_landmark.txt'
#获取数据加载器
dataload=readPnetData(label_file=label_file,batch_size=1)
for i,(img,label,gt_bbox,gt_landmark) in enumerate(dataload):
  print(i)
  if i==3:
      break
  print("img",img.shape)
  print("label",label)
  print("gt_bbox:",gt_bbox)
  print("gt_landmark:",gt_landmark)

def readPnetData(label_file,batch_size,num_workers=1):
    print("label_file is {}".format(label_file))  
    f = open(label_file, 'r')
    samples=f.readlines()
    print("Total size of the dataset is: ", len(samples))
    img_size=12
    data_transforms = {
        'train': transforms.Compose([
            transforms.ColorJitter(brightness=1,contrast=1,hue=0.5),# 随机从 0 ~ 2 之间亮度变化,1 表示原图
            transforms.RandomGrayscale(p=0.5),    # 以0.5的概率进行灰度化
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            #transforms.Resize((img_size,img_size),3),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        }
#将整个数据内容传进该类
image_datasets=PNetDataSet(samples,data_transforms['train'])
DataLoaderS = torch.utils.data.DataLoader(image_datasets,batch_size=batch_size,shuffle=True,num_workers=num_workers) 
    return DataLoaderS

打印结果

关键点:每次可能读入多个读片,但是打印的时候需要时间,不能完全打印出来数据加载器已经加载的数据。
例:
img_path: …/prepare_data//DATA/12/positive/98860.jpg对应的label=1
img_path: …/prepare_data//DATA/12/negative/754620.jpg 对应的label=0
img_path: …/prepare_data//DATA/12/negative/826367.jpg 对应的label=0
但是只打印出来前面两个数据(打印两个数据时此时已经读入了三个数据)

label_file is ../prepare_data/DATA/imglists/PNet/train_PNet_landmark.txt
Total size of the dataset is:  1428957
img_path: ../prepare_data//DATA/12/positive/98860.jpg
img_path: ../prepare_data//DATA/12/negative/754620.jpg
0
img torch.Size([1, 3, 12, 12])
img_path: ../prepare_data//DATA/12/negative/826367.jpg
label tensor([1])
gt_bbox: {'xmax': tensor([-0.1900], dtype=torch.float64), 'ymax': tensor([-0.0600], dtype=torch.float64), 'ymin': tensor([0.0100], dtype=torch.float64), 'xmin': tensor([-0.1100], dtype=torch.float64)}
gt_landmark: {'xnose': tensor([0]), 'xlefteye': tensor([0]), 'xrightmouth': tensor([0]), 'yrighteye': tensor([0]), 'ylefteye': tensor([0]), 'xleftmouth': tensor([0]), 'ynose': tensor([0]), 'yrightmouth': tensor([0]), 'xrighteye': tensor([0]), 'yleftmouth': tensor([0])}
1
img torch.Size([1, 3, 12, 12])
img_path: ../prepare_data//DATA/12/train_PNet_landmark_aug/43001.jpg
label tensor([0])
gt_bbox: {'xmax': tensor([0]), 'ymax': tensor([0]), 'ymin': tensor([0]), 'xmin': tensor([0])}
gt_landmark: {'xnose': tensor([0]), 'xlefteye': tensor([0]), 'xrightmouth': tensor([0]), 'yrighteye': tensor([0]), 'ylefteye': tensor([0]), 'xleftmouth': tensor([0]), 'ynose': tensor([0]), 'yrightmouth': tensor([0]), 'xrighteye': tensor([0]), 'yleftmouth': tensor([0])}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值