重写Dataset类,对多个文件夹指定相同标签
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
class ImageDataset(Dataset):
"""
通过包含数据路径和要指定为分类的类名
data_path:数据路径,到train或者val
img_true:二分类中,指定的类
transform: 数据处理,对图像进行随机裁剪, 以及转换成tensor。外界定义好是train的tranform或者是val的transform,然后传进来。也可以在里面定义
"""
def __init__(self, data_path, img_true, transform=None):
super(ImageDataset,self).__init__()
data_floder = os.listdir(data_path)
data = []
for item in data_floder:
img_list = os.listdir(os.path.join(data_path,item))
for img in img_list:
if item == img_true:
data.append((os.path.join(data_path,item,img),int(1)))
else:
data.append((os.path.join(data_path,item,img),int(0)))
self.data = data
self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None
def __getitem__(self,idx):
image_name, labels = self.data[idx]
image = Image.open(image_name)
if self.transform:
data = self.transform(image)
return data,labels
def __len__(self):
return len(self.data)
# data_transform = {
# "train": transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪到224*224像素大小
# transforms.RandomHorizontalFlip(), # 在水平方向上随机翻转
# transforms.ToTensor(), # 转换为tensor
# transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]), # 标准化处理
# "val": transforms.Compose([transforms.RandomResizedCrop(224),
# transforms.ToTensor(),
# transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])}
# # 设置自己存放的数据集位置,并plot展示
# dataset = ImageDataset("/py_workspace/data_set/flower_data/train","daisy")
# print(dataset.__getitem__(0))
# # plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
# plt.show()
该博客介绍了如何重写`Dataset`类以处理多个文件夹中的图像,根据指定的类标签进行二分类。`ImageDataset`类接受数据路径、目标类别和可选的图像变换。它遍历文件夹,将属于指定类别的图像标记为1,否则标记为0。同时,提供了预处理图像的方法,如随机裁剪、水平翻转和归一化。博客还给出了训练和验证阶段的图像变换示例。

被折叠的 条评论
为什么被折叠?



