处理自定义的数据集
1、PyTorch 官方模板
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
    """Map-style image dataset described by an annotations CSV.

    ``pd.read_csv`` treats the first CSV line as the header. In every data
    row, column 0 is an image file name relative to ``img_dir`` and column 1
    is its label.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        """Load the annotations table and remember the optional transforms.

        Args:
            annotations_file: path (or buffer) accepted by ``pd.read_csv``.
            img_dir: directory that the file names in column 0 are joined onto.
            transform: optional callable applied to each loaded image.
            target_transform: optional callable applied to each label.
        """
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.img_labels = pd.read_csv(annotations_file)

    def __len__(self):
        """Return the number of annotated samples (CSV data rows)."""
        return self.img_labels.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``, applying any configured transforms."""
        table = self.img_labels
        file_name = table.iloc[idx, 0]
        image = read_image(os.path.join(self.img_dir, file_name))
        label = table.iloc[idx, 1]
        # Transforms are optional; only apply the ones that were supplied.
        image = self.transform(image) if self.transform else image
        label = self.target_transform(label) if self.target_transform else label
        return image, label
2、自定义数据集（以 Kaggle Classify Leaves 竞赛为例）
class LeavesData(Dataset):
    """Kaggle "Classify Leaves" dataset with a deterministic train/valid split.

    The CSV is read with ``header=None``, so row 0 is the header line and the
    data rows are rows 1..N. Column 0 holds an image path relative to
    ``img_path``; for the train/valid CSV, column 1 holds the class name.
    The first ``int(N * (1 - valid_ratio))`` data rows form the training
    split and the remaining rows form the validation split.
    """

    _MODES = ('train', 'valid', 'test')

    def __init__(self, csv_path, img_path, mode='train', valid_ratio=0.2, resize_height=256, resize_weirgt=256):
        """Read the CSV and select the rows belonging to ``mode``.

        Args:
            csv_path (str): path of the CSV file.
            img_path (str): directory the image paths in column 0 are relative to.
            mode (str): 'train', 'valid' or 'test'.
            valid_ratio (float): fraction of data rows reserved for validation
                (ignored in 'test' mode).
            resize_height (int): stored on the instance only; the image
                transforms currently always resize to 224x224.
            resize_weirgt (int): stored on the instance only (name kept for
                backward compatibility); not used by the transforms either.

        Raises:
            ValueError: if ``mode`` is not one of 'train', 'valid', 'test'.
        """
        if mode not in self._MODES:
            raise ValueError(
                "mode must be one of {}, got {!r}".format(self._MODES, mode))
        self.resize_height = resize_height
        self.resize_weirgt = resize_weirgt
        self.file_path = img_path
        self.mode = mode
        self.data_info = pd.read_csv(csv_path, header=None)
        # Number of data rows, excluding the header line at row 0.
        self.data_len = len(self.data_info.index) - 1
        # Number of data rows assigned to the training split.
        self.train_len = int(self.data_len * (1 - valid_ratio))
        # Data rows start at row 1, so the first train_len of them are rows
        # 1..train_len; the exclusive slice end is therefore train_len + 1.
        split = self.train_len + 1
        if mode == 'train':
            self.train_image = np.asarray(self.data_info.iloc[1:split, 0])
            self.train_label = np.asarray(self.data_info.iloc[1:split, 1])
            self.image_arr = self.train_image
            self.labe_arr = self.train_label
        elif mode == 'valid':
            self.valid_image = np.asarray(self.data_info.iloc[split:, 0])
            self.valid_label = np.asarray(self.data_info.iloc[split:, 1])
            self.image_arr = self.valid_image
            self.labe_arr = self.valid_label
        else:  # 'test': every data row, no labels
            self.test_image = np.asarray(self.data_info.iloc[1:, 0])
            self.image_arr = self.test_image
        self.real_len = len(self.image_arr)
        print('Finished reading the {} set of Leaves Dataset {} samples found'.format(mode, self.real_len))

    def _build_transform(self):
        """Return the image pipeline: resize to 224x224, random flips in 'train' mode only, then ToTensor."""
        steps = [transforms.Resize((224, 224))]
        if self.mode == 'train':
            steps += [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
            ]
        steps.append(transforms.ToTensor())
        return transforms.Compose(steps)

    def __getitem__(self, index):
        """Return the image tensor in 'test' mode, else ``(image, class_index)``.

        The class index comes from ``class_to_num``, a name this class does
        not define; presumably a module-level dict from class name to integer
        built elsewhere in the notebook -- verify it exists before use.
        """
        image_path = os.path.join(self.file_path, self.image_arr[index])
        # Use the file as a context manager so its handle is released once the
        # transform has produced a tensor.
        with Image.open(image_path) as img:
            img_tensor = self._build_transform()(img)
        if self.mode == 'test':
            return img_tensor
        label = self.labe_arr[index]
        return img_tensor, class_to_num[label]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return self.real_len
train_path = r'E:/Learn_MechineLearning_DeepLearning/pytorch-Learning/Kaggle/Classify-leaves/classify-leaves/train.csv'
test_path = r'E:/Learn_MechineLearning_DeepLearning/pytorch-Learning/Kaggle/Classify-leaves/classify-leaves/test.csv'
img_path = r'E:/Learn_MechineLearning_DeepLearning/pytorch-Learning/Kaggle/Classify-leaves/classify-leaves/'

# The loaders below need these datasets; train and valid both come from
# train.csv and are split by LeavesData according to valid_ratio.
train_dataset = LeavesData(train_path, img_path, mode='train')
val_dataset = LeavesData(train_path, img_path, mode='valid')
test_dataset = LeavesData(test_path, img_path, mode='test')

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
)
# Evaluation loaders are not shuffled: validation metrics don't need it, and
# keeping test batches in test.csv order lets predictions be matched back to
# their rows.
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=16,
    shuffle=False,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=16,
    shuffle=False,
)

# Preview the first image of each of the first few training batches.
NUM_PREVIEW_BATCHES = 10
for batch_idx, (images, _labels) in enumerate(train_loader):
    # Channel-first tensor -> channel-last array, the layout plt.imshow expects.
    preview = np.transpose(images[0].detach().numpy(), (1, 2, 0))
    plt.imshow(preview)
    plt.show()
    if batch_idx + 1 == NUM_PREVIEW_BATCHES:
        break