# --- Section: Dataloader_Dir — basic torch DataLoader usage on CIFAR10 ---
from torch.utils.data import DataLoader
import torchvision


def basic(root="D:/dev/python/pyWork/Season2/Stage1/data/myimg", batch_size=4):
    """Demonstrate basic DataLoader usage on the CIFAR10 test split.

    Loads the CIFAR10 test set from ``root``, prints the shape and label of
    a single sample, then iterates the whole set in batches, printing each
    batch's image-tensor shape and label tensor.

    Args:
        root: directory containing the CIFAR10 data.
        batch_size: number of samples fetched per batch.

    Returns:
        None.
    """
    test_data = torchvision.datasets.CIFAR10(
        root=root,
        train=False,
        transform=torchvision.transforms.ToTensor(),
    )
    # shuffle=True: reshuffle on each pass; num_workers=0: load in the main
    # process (extra worker processes are problematic on Windows);
    # drop_last=False: keep the final, possibly smaller, batch.
    test_loader = DataLoader(
        dataset=test_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=False,
    )

    # Inspect one raw sample: image tensor shape and its class index.
    img, target = test_data[0]
    print(img.shape)
    print(target)

    # Iterate batches: imgs is (batch, C, H, W), targets is (batch,).
    for imgs, targets in test_loader:
        print(imgs.shape)
        print(targets)
    return None


if __name__ == "__main__":
    # Parameters:
    #   batch_size: number of samples fetched at a time
    #   shuffle: whether to reshuffle the data on each pass (default False)
    #   num_workers: worker-process count, default 0 = main process only
    #                (nonzero values are problematic on Windows)
    #   drop_last: whether to drop a final batch smaller than batch_size
    #              (default False = keep it)
    # Demo 1: basic usage
    basic()
    print()
# --- Section: Dataset_AC — custom Dataset over ants/bees image folders ---
from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):
    """Dataset of images laid out as ``root_dir/label_dir/*``.

    The name of the class subdirectory (``label_dir``) doubles as the label
    returned for every image inside it.
    """

    def __init__(self, root_dir, label_dir):
        """
        Args:
            root_dir: directory holding one subdirectory per class.
            label_dir: name of the class subdirectory; also used as the label.
        """
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        # Sort for a deterministic index -> file mapping; os.listdir order
        # is arbitrary and platform-dependent.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """Return ``(PIL.Image, label)`` for the idx-th image in the folder."""
        img_name = self.img_path[idx]
        # Reuse the class-directory path precomputed in __init__.
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        """Number of images in the class directory."""
        return len(self.img_path)


# Build the training set from the two class folders; "+" on Datasets yields
# a ConcatDataset (ants samples followed by bees samples).
root_dir = "D:/dev/resources/MLR/hymenoptera_data/hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset