# Custom dataset (自定义数据集)
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
import torch.nn.functional as F
import os
class img_segData(Dataset):
    """Paired image / segmentation-label dataset read from two sibling folders.

    Every file in ``<path>/<data_file>`` is one sample. Its label is the file
    in ``<path>/<label_file>`` whose name is the image's stem followed by
    ``label_suffix`` (e.g. ``cat.jpg`` -> ``cat-profile.jpg``).
    """

    def __init__(self, img_h=256, img_w=256, path="./data/img_seg", data_file="images",
                 label_file="profiles", preprocess=True, label_suffix="-profile.jpg"):
        """
        Initialise the dataset.

        :param img_h: target height after resizing
        :param img_w: target width after resizing
        :param path: root directory of the dataset
        :param data_file: name of the sub-folder holding the input images
        :param label_file: name of the sub-folder holding the label images
        :param preprocess: if True, ``__getitem__`` returns resized float tensors
            (the image additionally normalised with mean=std=0.5 per channel);
            otherwise it returns the raw PIL images
        :param label_suffix: appended to an image's stem to form its label
            file name
        """
        super(img_segData, self).__init__()
        # os.listdir order is filesystem-dependent; sort so that the
        # index -> file mapping is reproducible across machines and runs.
        self.file_list = sorted(os.listdir(os.path.join(path, data_file)))
        self.data_file = data_file
        self.label_files = label_file
        self.label_suffix = label_suffix
        self.path = path
        self.img_h = img_h
        self.img_w = img_w
        self.preprocess = preprocess

    def __len__(self):
        """Return the number of samples (entries in the image folder)."""
        return len(self.file_list)

    def _paths(self, item):
        """Return ``(image_path, label_path)`` for sample index ``item``."""
        img_name = self.file_list[item]
        # splitext strips only the final extension, so "a.b.jpg" -> "a.b"
        # (split(".")[0] would have truncated it to "a").
        stem = os.path.splitext(img_name)[0]
        img_path = os.path.join(self.path, self.data_file, img_name)
        label_path = os.path.join(self.path, self.label_files, stem + self.label_suffix)
        return img_path, label_path

    def _build_transforms(self):
        """Return the ``(image_transform, label_transform)`` pair for the current size."""
        # torchvision's Resize takes the target size as (height, width).
        size = (self.img_h, self.img_w)
        img_tf = transforms.Compose([
            transforms.Resize(size=size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ])
        # NOTE(review): Resize interpolates bilinearly by default, which blends
        # values along mask edges. If labels are meant to be discrete class
        # ids, consider interpolation=transforms.InterpolationMode.NEAREST.
        label_tf = transforms.Compose([
            transforms.Resize(size=size),
            transforms.ToTensor(),
        ])
        return img_tf, label_tf

    def __getitem__(self, item):
        """
        Load one (image, label) sample.

        :param item: sample index into ``self.file_list``
        :return: ``(img, label)``: tensors if ``preprocess`` is True,
            otherwise PIL images
        :raises FileNotFoundError: if the image or its label file is missing
        """
        img_path, label_path = self._paths(item)
        # Context managers guarantee both file handles are released, even if
        # decoding or a transform raises.
        with Image.open(img_path) as img_file, Image.open(label_path) as label_file:
            if self.preprocess:
                img_tf, label_tf = self._build_transforms()
                # Normalize uses a 3-channel mean/std, so force RGB; grayscale,
                # palette or RGBA inputs would otherwise fail there.
                img = img_tf(img_file.convert("RGB"))
                label = label_tf(label_file)
            else:
                # copy() loads the pixels into a standalone image that stays
                # usable after the source file is closed.
                img = img_file.copy()
                label = label_file.copy()
        return img, label
if __name__ == '__main__':
trans_data = img_segData()
img,label = trans_data.__getitem__(5)
print(img.size(),label.size())
label = torch.where(label==