
# 一、数据集的处理与加载 (1. Dataset processing and loading)
class CatDogDataset(Dataset):
def __init__(self, data_dir, mode="train", split_n=0.9, rng_seed=620, transform=None):
    """Store the dataset configuration and build the sample index.

    Args:
        data_dir: directory that ``_get_img_info`` lists for ``.jpg`` files.
        mode: split selector stored on the instance (presumably "train" vs. a
            validation mode -- TODO confirm against ``_get_img_info``).
        split_n: fraction stored on the instance; presumably the train/valid
            split ratio -- confirm in ``_get_img_info``.
        rng_seed: seed that ``_get_img_info`` passes to ``random.seed`` before
            shuffling the file names.
        transform: optional callable applied to each loaded image in
            ``__getitem__``.
    """
    # Configuration first: _get_img_info reads data_dir and rng_seed (and
    # possibly mode / split_n), so these must exist before the index is built.
    self.mode, self.data_dir = mode, data_dir
    self.rng_seed, self.split_n = rng_seed, split_n
    # data_info is a sequence of (image path, label) pairs; __getitem__ reads
    # one pair per index.
    self.data_info = self._get_img_info()
    self.transform = transform
def __getitem__(self, index):
    """Load one sample by position.

    Args:
        index: position in ``self.data_info``.

    Returns:
        ``(img, label)``: the image converted to RGB (channel values 0~255),
        passed through ``self.transform`` when one is set, together with the
        label stored alongside its path in ``data_info``.
    """
    path_img, label = self.data_info[index]
    # Open inside a context manager so the file is closed deterministically
    # rather than relying on the decoder closing it after load(). convert()
    # returns a new, fully loaded image, so it stays usable after the source
    # image is closed.
    with Image.open(path_img) as raw:
        img = raw.convert('RGB')  # channel values 0~255
    if self.transform is not None:
        img = self.transform(img)  # apply the transform here, e.g. convert to a tensor
    return img, label
def __len__(self):
    """Return the number of samples in ``self.data_info``.

    Raises:
        RuntimeError: if ``data_info`` is empty, e.g. ``data_dir`` contains no
            ``.jpg`` files. RuntimeError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
    """
    n_samples = len(self.data_info)
    if n_samples == 0:
        # Fail loudly instead of reporting a zero-length dataset; an empty
        # index most likely means a wrong data_dir.
        raise RuntimeError("\ndata_dir:{} is a empty dir! Please checkout your path to images!".format(self.data_dir))
    return n_samples
def _get_img_info(self):
img_names = os.listdir(self.data_dir)
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
random.seed(self.rng_seed)
random.shuffle(img_names)
img_labels = [0 if n.startswith('cat') el