Creating a data reading class
A custom dataset class must inherit from data.Dataset and must override the __getitem__ and __len__ methods.
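For reference, here is a minimal sketch of such a subclass (the class name ListDataSet and the in-memory samples are made up purely for illustration):

import torch.utils.data as data

class ListDataSet(data.Dataset):       # hypothetical minimal example
    def __init__(self, samples):
        self.samples = samples         # any indexable container of samples

    def __getitem__(self, index):
        return self.samples[index]     # return one sample by its index

    def __len__(self):
        return len(self.samples)       # total number of samples

The PNet dataset below follows this same pattern, with the extra work of opening the image and parsing the annotation line inside __getitem__.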
import torch.utils.data as data
from PIL import Image, ImageFile

# Allow PIL to load truncated image files instead of raising an error
ImageFile.LOAD_TRUNCATED_IMAGES = True

class PNetDataSet(data.Dataset):
    def __init__(self, imgs, transforms=None):
        '''
        Hold all annotation lines (image path, label, box/landmark offsets)
        and the transforms to apply to each image.
        '''
        self.imgs = imgs
        self.transforms = transforms

    def __getitem__(self, index):
        '''
        Return the data of one image:
        (image tensor, ground-truth label, ground-truth box offsets, landmark coordinates)
        '''
        line = self.imgs[index]  # e.g. ./DATA/12/positive/343776.jpg 1 0.12 0.05 -0.11 -0.01
        line = line.strip().split(" ")
        img_path = "../prepare_data/" + line[0][1:]  # ../prepare_data/DATA/12/positive/343776.jpg
        print("img_path:", img_path)
        data = Image.open(img_path)
        if len(data.getbands()) != 3:
            # print("{} is not RGB".format(img_path))
            data = data.convert('RGB')
        if self.transforms is not None:
            data = self.transforms(data)
        label = int(line[1])
        gt_bbox, gt_landmark = self.__getXy__(line)
        return data, label, gt_bbox, gt_landmark
    # Parse the box offsets and landmark coordinates from one annotation line
    def __getXy__(self, info):
        gt_bbox = dict()
        gt_landmark = dict()
        gt_bbox['xmin'] = 0
        gt_bbox['ymin'] = 0
        gt_bbox['xmax'] = 0
        gt_bbox['ymax'] = 0
        gt_landmark['xlefteye'] = 0
        gt_landmark['ylefteye'] = 0
        gt_landmark['xrighteye'] = 0
        gt_landmark['yrighteye'] = 0
        gt_landmark['xnose'] = 0
        gt_landmark['ynose'] = 0
        gt_landmark['xleftmouth'] = 0
        gt_landmark['yleftmouth'] = 0
        gt_landmark['xrightmouth'] = 0
        gt_landmark['yrightmouth'] = 0
        if len(info) == 6:      # path + label + 4 box offsets
            gt_bbox['xmin'] = float(info[2])
            gt_bbox['ymin'] = float(info[3])
            gt_bbox['xmax'] = float(info[4])
            gt_bbox['ymax'] = float(info[5])
        if len(info) == 12:     # path + label + 10 landmark coordinates
            gt_landmark['xlefteye'] = float(info[2])
            gt_landmark['ylefteye'] = float(info[3])
            gt_landmark['xrighteye'] = float(info[4])
            gt_landmark['yrighteye'] = float(info[5])
            gt_landmark['xnose'] = float(info[6])
            gt_landmark['ynose'] = float(info[7])
            gt_landmark['xleftmouth'] = float(info[8])
            gt_landmark['yleftmouth'] = float(info[9])
            gt_landmark['xrightmouth'] = float(info[10])
            gt_landmark['yrightmouth'] = float(info[11])
        return gt_bbox, gt_landmark

    def __len__(self):
        '''
        Return the number of images in the dataset.
        '''
        return len(self.imgs)
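As the two if branches show, the annotation lines are handled by length: 6 fields (path, label and four box offsets) fill the box, 12 fields (path, label and ten landmark coordinates) fill the landmarks, and any other length (for example a 2-field "path label" negative line) leaves everything at 0. Below is a quick sanity check of __getXy__ on made-up lines; all numeric values and the landmark label are invented for illustration:

ds = PNetDataSet(imgs=[])                     # empty list: we only exercise __getXy__ here
pos_line = "./DATA/12/positive/1.jpg 1 0.12 0.05 -0.11 -0.01".split(" ")
lmk_line = "./DATA/12/train_PNet_landmark_aug/1.jpg -2 0.1 0.2 0.3 0.2 0.2 0.4 0.1 0.6 0.3 0.6".split(" ")
bbox, _ = ds.__getXy__(pos_line)              # 6 fields -> box offsets filled, landmarks stay 0
print(bbox['xmin'], bbox['ymax'])             # 0.12 -0.01
_, landmark = ds.__getXy__(lmk_line)          # 12 fields -> landmarks filled, box stays 0
print(landmark['xlefteye'], landmark['ylefteye'])  # 0.1 0.2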
Using the class
The label file path is set as follows, and the helper readPnetData (defined below) builds the data loader:
label_file = '../prepare_data/DATA/imglists/PNet/train_PNet_landmark.txt'
# Build the data loader
dataload = readPnetData(label_file=label_file, batch_size=1)
for i, (img, label, gt_bbox, gt_landmark) in enumerate(dataload):
    print(i)
    if i == 3:
        break
    print("img", img.shape)
    print("label", label)
    print("gt_bbox:", gt_bbox)
    print("gt_landmark:", gt_landmark)
import torch
from torchvision import transforms

def readPnetData(label_file, batch_size, num_workers=1):
    print("label_file is {}".format(label_file))
    with open(label_file, 'r') as f:
        samples = f.readlines()
    print("Total size of the dataset is: ", len(samples))
    img_size = 12
    data_transforms = {
        'train': transforms.Compose([
            transforms.ColorJitter(brightness=1, contrast=1, hue=0.5),  # brightness jittered randomly in [0, 2]; a factor of 1 keeps the original image
            transforms.RandomGrayscale(p=0.5),  # convert to grayscale with probability 0.5
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            # transforms.Resize((img_size, img_size), 3),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    # Pass all annotation lines to the dataset class
    image_datasets = PNetDataSet(samples, data_transforms['train'])
    DataLoaderS = torch.utils.data.DataLoader(image_datasets, batch_size=batch_size,
                                              shuffle=True, num_workers=num_workers)
    return DataLoaderS
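Because each sample returns gt_bbox and gt_landmark as plain Python dicts, the DataLoader's default collate function batches them key by key. That is why the printed gt_bbox below is a dict whose values are tensors of length batch_size: parsed offsets become float64 tensors, while keys that kept their integer 0 default become int64 tensors. A minimal sketch of this behaviour on a toy in-memory dataset (the values are invented):

import torch

toy = [(0, {'xmin': 0.12, 'ymin': 0.05}),
       (1, {'xmin': 0.30, 'ymin': 0.40})]
loader = torch.utils.data.DataLoader(toy, batch_size=2)
for label, box in loader:
    print(label)  # tensor([0, 1])
    print(box)    # {'xmin': tensor([0.1200, 0.3000], dtype=torch.float64), 'ymin': tensor([0.0500, 0.4000], dtype=torch.float64)}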
Printed results
Key point: the DataLoader worker reads several images ahead of the training loop, and printing takes time, so the console output does not show every sample the loader has already read.
Example:
img_path: ../prepare_data//DATA/12/positive/98860.jpg corresponds to label = 1
img_path: ../prepare_data//DATA/12/negative/754620.jpg corresponds to label = 0
img_path: ../prepare_data//DATA/12/negative/826367.jpg corresponds to label = 0
Yet only the first two samples are printed in the log below (by the time two samples have been printed, three have already been read).
label_file is ../prepare_data/DATA/imglists/PNet/train_PNet_landmark.txt
Total size of the dataset is: 1428957
img_path: ../prepare_data//DATA/12/positive/98860.jpg
img_path: ../prepare_data//DATA/12/negative/754620.jpg
0
img torch.Size([1, 3, 12, 12])
img_path: ../prepare_data//DATA/12/negative/826367.jpg
label tensor([1])
gt_bbox: {'xmax': tensor([-0.1900], dtype=torch.float64), 'ymax': tensor([-0.0600], dtype=torch.float64), 'ymin': tensor([0.0100], dtype=torch.float64), 'xmin': tensor([-0.1100], dtype=torch.float64)}
gt_landmark: {'xnose': tensor([0]), 'xlefteye': tensor([0]), 'xrightmouth': tensor([0]), 'yrighteye': tensor([0]), 'ylefteye': tensor([0]), 'xleftmouth': tensor([0]), 'ynose': tensor([0]), 'yrightmouth': tensor([0]), 'xrighteye': tensor([0]), 'yleftmouth': tensor([0])}
1
img torch.Size([1, 3, 12, 12])
img_path: ../prepare_data//DATA/12/train_PNet_landmark_aug/43001.jpg
label tensor([0])
gt_bbox: {'xmax': tensor([0]), 'ymax': tensor([0]), 'ymin': tensor([0]), 'xmin': tensor([0])}
gt_landmark: {'xnose': tensor([0]), 'xlefteye': tensor([0]), 'xrightmouth': tensor([0]), 'yrighteye': tensor([0]), 'ylefteye': tensor([0]), 'xleftmouth': tensor([0]), 'ynose': tensor([0]), 'yrightmouth': tensor([0]), 'xrighteye': tensor([0]), 'yleftmouth': tensor([0])}
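If you want the img_path prints to appear in lockstep with the loop's own output, one option (a variant sketch, not part of the code above) is to pass num_workers=0 so each sample is loaded in the main process at the moment the loop requests it:

# Hypothetical variant: load in the main process so prints stay in order
dataload = readPnetData(label_file=label_file, batch_size=1, num_workers=0)
for i, (img, label, gt_bbox, gt_landmark) in enumerate(dataload):
    print("step", i, "img", img.shape, "label", label)
    if i == 3:
        break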