Pytorch重写Dataset类

该博客介绍了如何重写`Dataset`类以处理多个文件夹中的图像,根据指定的类标签进行二分类。`ImageDataset`类接受数据路径、目标类别和可选的图像变换。它遍历文件夹,将属于指定类别的图像标记为1,否则标记为0。同时,提供了预处理图像的方法,如随机裁剪、水平翻转和归一化。博客还给出了训练和验证阶段的图像变换示例。

重写Dataset类,对多个文件夹指定相同标签

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image


class ImageDataset(Dataset):

    """
    通过包含数据路径和要指定为分类的类名
    data_path:数据路径,到train或者val
    img_true:二分类中,指定的类
    transform: 数据处理,对图像进行随机裁剪, 以及转换成tensor。外界定义好是train的tranform或者是val的transform,然后传进来。也可以在里面定义
    """

    def __init__(self, data_path, img_true, transform=None):
        super(ImageDataset,self).__init__()
        data_floder = os.listdir(data_path)
        data = []
        for item in data_floder:
            img_list = os.listdir(os.path.join(data_path,item))
            
            for img in img_list:
                if item == img_true:
                    data.append((os.path.join(data_path,item,img),int(1)))
                else:
                    data.append((os.path.join(data_path,item,img),int(0)))
                
        self.data = data
        self.transform = transform # 对输入图像进行预处理,这里并没有做,预设为None

       
        
    def __getitem__(self,idx):
        image_name, labels = self.data[idx]
        image = Image.open(image_name)
        if self.transform:
            data = self.transform(image)
        return data,labels
                 
    def __len__(self):
        return len(self.data)



# data_transform = {
#     "train": transforms.Compose([transforms.RandomResizedCrop(224),  # 随机裁剪到224*224像素大小
#                                 transforms.RandomHorizontalFlip(),   # 在水平方向上随机翻转
#                                 transforms.ToTensor(),               #  转换为tensor
#                                 transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]),  # 标准化处理
#     "val": transforms.Compose([transforms.RandomResizedCrop(224),
#                              transforms.ToTensor(),
#                              transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])}
# # 设置自己存放的数据集位置,并plot展示        


# dataset = ImageDataset("/py_workspace/data_set/flower_data/train","daisy")

# print(dataset.__getitem__(0))
# # plt.imshow(imageloader.__getitem__(0)) # 以图像形式展示
# plt.show()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值