from torch.utils import data
class MyData(data.Dataset):
    """Minimal skeleton of a custom dataset built on ``data.Dataset``.

    Replace the placeholder bodies with project-specific logic. The
    placeholders keep the class importable and usable as an empty dataset.
    """

    def __init__(self):
        """Set up the paths and configuration the dataset needs."""
        super().__init__()
        # Placeholder sample index; fill it with the real entries.
        self.samples = []

    def function(self):
        """Placeholder for a user-defined internal helper."""

    def __getitem__(self, index):
        """Return the sample at ``index`` (placeholder: an empty list)."""
        return []

    def __len__(self):
        """Return the number of samples in the dataset."""
        # Must be an int: returning the builtin ``len`` itself would make
        # ``len(dataset)`` raise TypeError.
        return len(self.samples)
def __init__(self): 在这里初始化你需要用到的路径和各种配置
def function(self): 在这里定义你自己的内部函数
def __getitem__(self, index): 这个函数很关键,用于在迭代过程中按索引取数据
def __len__(self): 这个函数必须实现,返回数据集的长度
下面是一个CV任务的数据集示例:
class FFDataset(data.Dataset):
    """Dataset of binary-labelled face images read from per-video folders.

    Every file found in the folders listed in ``info_list`` becomes one
    sample. ``__getitem__`` returns the transformed image, a float label and
    a constant ``(1, 14, 14)`` mask filled with that label.

    NOTE(review): this class uses ``os``, ``cv2``, ``np``, ``pd`` and
    ``transforms`` and calls ``self.crop_face``; none of them is imported or
    defined in the visible part of the file -- confirm they are provided
    elsewhere.
    """

    def __init__(self, root, info_list):
        """Index the images and build the augmentation pipeline.

        Args:
            root: Directory that contains the image folders.
            info_list: Path to a header-less, comma-separated file whose
                second column holds folder names relative to ``root``.
        """
        super().__init__()
        self.root = root
        self.train_list = self.collect_image(info_list, root)
        # Resize to 224x224, apply random flip/rotation, then scale each
        # channel with mean 0.5 and std 0.5.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])

    def collect_image(self, info_list, root):
        """Return the paths of all files inside the listed folders.

        Args:
            info_list: Path to the header-less CSV; column 1 is the folder
                name.
            root: Directory the folder names are relative to.

        Returns:
            A list of paths, one per directory entry, in the order the
            folders appear in the CSV (entries within a folder follow
            ``os.listdir`` order).
        """
        frame = pd.read_csv(info_list, delimiter=',', header=None)
        image_path_list = []
        for video_name in frame.iloc[:, 1]:
            img_dir = os.path.join(root, str(video_name))
            # Store the full path of each entry; the previous code appended
            # an undefined name here and raised NameError.
            image_path_list.extend(
                os.path.join(img_dir, img_name)
                for img_name in os.listdir(img_dir)
            )
        return image_path_list

    def read_image(self, path):
        """Load the image at ``path`` and return its randomly scaled face crop.

        The crop scale is drawn uniformly from 1.2, 1.3 and 1.4.
        """
        img = cv2.imread(path)
        face_scale = np.random.randint(12, 15) / 10.0
        # NOTE(review): crop_face is not defined in the visible code.
        return self.crop_face(path, img, face_scale)

    def __getitem__(self, index):
        """Return ``[img, label, mask]`` for the sample at ``index``.

        ``label`` is 1.0 when the third ``/``-separated component of the
        stored path is ``'live_images'`` and 0.0 otherwise; ``mask`` is a
        float32 array of shape ``(1, 14, 14)`` filled with ``label``.
        """
        image_path = self.train_list[index]
        img = self.transform(self.read_image(image_path))
        # NOTE(review): component index 2 depends on how ``root`` is
        # spelled -- confirm against the directory layout in use.
        img_type = image_path.split('/')[2]
        label = 1.0 if img_type == 'live_images' else 0.0
        mask = np.ones((1, 14, 14), dtype=np.float32) * label
        return [img, label, mask]

    def __len__(self):
        """Return the number of indexed image paths."""
        return len(self.train_list)