使用 PyTorch 中的 Dataset 和 DataLoader 类加载数据集：
import torch
from torchvision import transforms, datasets
import os,sys
from torch.utils.data import Dataset,DataLoader
from PIL import Image
import numpy as np
# Training-time preprocessing pipeline:
#   1. random crop, resized to 224x224;
#   2. random horizontal flip;
#   3. conversion to a tensor;
#   4. per-channel normalization.
#
# RandomResizedCrop replaces the deprecated RandomSizedCrop alias. It takes the
# same arguments and behaves the same way. The old name only emitted a
# deprecation warning and is not available in recent torchvision versions.
#
# NOTE(review): RandomResizedCrop and RandomHorizontalFlip are random. If this
# pipeline is applied to an image and its mask in separate calls, the two will
# receive different crops/flips and fall out of alignment. Normalize below also
# expects 3 channels. Confirm how my_dataset uses this transform.
data_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Standard ImageNet per-channel (R, G, B) mean and std.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
class my_dataset(Dataset):
def __init__(self,img_path, mask_path, data_transform=None):
self.img_path = img_path
self.mask_path = mask_path
self.transforms = data_transform
self.img_list, self.mask_list = [],[]
for file in os.listdir(self.img_path):
img_path = os.path.join(self.img_path, file)
self.img_list.append(img_path)
for file in os.listdir(self.mask_path):
mask_path = os.pa