加载本地数据
数据集: Stanford Dogs
操作:
- dataset:
from torch.utils.data import Dataset, DataLoader
class Mydata(Dataset):
    """Image-folder dataset: each sub-directory of ``root`` holds one class.

    All images are opened eagerly in ``__init__``, converted to RGB and kept
    in memory (``self.data``). The integer label of an image (``self.label``)
    is the index of its class directory in the sorted list of
    sub-directories of ``root`` (``self.classes``).

    Args:
        root: Directory whose sub-directories each contain the images of one
            class. Plain files directly inside ``root`` are ignored.
        train: Stored as ``self.train``; not otherwise used by this class.
        transform: Optional callable applied to each image in
            ``__getitem__``.
        target_transform: Optional callable applied to each integer label in
            ``__getitem__``.
    """

    def __init__(self, root, train=True, transform=None,
                 target_transform=None):
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        # Sort so label indices are reproducible (os.listdir order is
        # arbitrary), and skip stray files such as .DS_Store, which would
        # otherwise make the inner os.listdir fail.
        self.classes = sorted(
            name for name in os.listdir(self.root)
            if os.path.isdir(os.path.join(self.root, name))
        )
        self.label = []
        self.data = []
        for index, class_name in enumerate(self.classes):
            class_dir = os.path.join(self.root, class_name)
            for file_name in sorted(os.listdir(class_dir)):
                # The context manager closes the file handle; convert()
                # returns a separate, loaded image that outlives it.
                with Image.open(os.path.join(class_dir, file_name)) as img:
                    self.data.append(img.convert('RGB'))
                self.label.append(index)

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``, with any transforms applied.

        When ``transform`` / ``target_transform`` is None, the stored image /
        label is returned unchanged.
        """
        img, target = self.data[index], self.label[index]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.data)
# Build the dataset from ./imgg. The Compose list runs in order: every image
# is resized to 32x32 first, then converted by ToTensor.
mydata = Mydata(
    './imgg',
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]),
)
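注意 transforms.Compose 中的各个变换按列表顺序依次执行，因此顺序很重要：例如 ToTensor 会把 PIL 图像转换为张量，所以 Normalize 这类只接受张量的变换必须放在 ToTensor 之后，而只接受 PIL 图像的变换应放在 ToTensor 之前。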
torchvision.transforms解析
- dataloader:
# Wrap the Stanford Dogs dataset `mydata` in a DataLoader: batches of 10,
# fixed (unshuffled) order, 2 worker processes. `DataLoader` comes from the
# `from torch.utils.data import ...` line in the dataset section.
# NOTE(review): with num_workers > 0 on platforms that spawn worker
# processes (Windows, macOS), run this under `if __name__ == '__main__':`.
cifarLoader = DataLoader(mydata, batch_size=10, shuffle=False, num_workers=2)