1.加载数据
- Dataset
- 提供一种方式去获取数据及其Iabel
- Dataloader
- 为后面的网络提供不同的数据形式
Dataset
1.如何获取每一个数据及其label
2.告诉我们总共有多少的数据
from torch.utils.data import Dataset
from PIL import Image
import os
class myData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(self.root_dir,self.label_dir)
self.img_path=os.listdir(self.path)
def __getitem__(self, idx):
img_name=self.img_path[idx]
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
img=Image.open(img_item_path)
label=self.label_dir
return img,label
def __len__(self):
return len(self.img_path)
root_dir="hymenoptera_data/train"
ants_label_dir=&#