Pytorch:加载图像数据集的方法

目录

1. 简介

2. torchvision.datasets.MNIST

3. torchvision.datasets.ImageFolder

4. torchvision.datasets.DatasetFolder

5. torch.utils.data.Datasets


1. 简介

在机器学习中,模型在每一轮训练前,会选择是否打乱数据,训练过程中呢,需要按批次去加载数据集(这里以图像分类为例),并喂给模型等,一系列管理数据集的操作。

在Pytorch深度学习框架中,上述提到的数据集管理功能的实现,通常都需要经过以下两个步骤:

     1、定义数据集torch以及torchvision中提供了多种方法来简化数据集定义的过程。

     2、Dataloader数据加载器:通过torch.utils.data.Dataloader实例化数据加载器,传入数据集,并配置相关参数,如批次大小等,以此获得一个可迭代对象,在使用for循环遍历。

第一个步骤,定义数据集,包括多种实现方式,如下:

  1torchvision.datasets.ImageFolder:用于标准的开源数据集。

  2、torchvision.datasets.ImageFolder:从文件夹中加载图像数据,自动生成标签。

  3、torchvision.datasets.DatasetFolder:更通用的工具,适用于自定义图像数据集。其中,图像和标签不一定按文件夹结构组织。

  4、torch.utils.data.Dataset:一个抽象基类,用户通过重写__init__、__len__和 __getitem__ 方法以提供数据和标签。

第二个步骤,实例化数据加载迭代器 torch.utils.data.Dataloader 类,涉及到的主要参数:

torch.utils.data.Dataloader(dataset,..)
  • dataset :数据集

  • batch_size :批处理数量

  • shuffle :每完成一个epoch,是否需要重新打乱数据

  • num_worker:采用多进程读取机制

  • collate_fn:可自定义函数,用于将一批数据合并成一个批次,默认为 None

  • drop_last :当样本数不能被batch_size整除时,是否舍弃最后一个batch的数据

在了解完数据集加载的两步骤后,主要变化的是第一步骤,如何定义数据集。所以,接下来还是围绕不同的数据集定义方式,实现最终的数据加载。 

2. torchvision.datasets.MNIST

目前,torchvision.datasets 库中已经收录了多种类型的数据集,一般都是各个图像处理领域内的开源标准数据集,如下列举了一些较为常见的数据集。

  • 图像分类:MNIST,CIFAR10, CIFAR100,ImageNet
  • 目标检测:COCO,VOC
  • 图像分割:COCO,VOC

这种开源数据集的加载,还是非常简单的,因为大佬们都已经封装好方法了,直接调用API就实现了。这里以mnist手写数字识别数据集为例,代码如下。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 数据转换
transform = transforms.Compose([transforms.ToTensor()])

# 下载 MNIST 数据集
train_dataset = datasets.MNIST(root='mnist_datasets', train=True, download=True,transform=transform)

test_dataset = datasets.MNIST(root='mnist_datasets', train=False, download=True,transform=transform)

#打印数据集第0项
print(train_dataset[0])

# 创建数据加载器,传入数据集,生成可迭代对象
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

# 使用for循环迭代输出
for images, labels in train_loader:
    print("images:",images.shape)
    print("labels",labels.shape)

代码执行后的结果,首先在定义的root目录下,下载了mnist数据集 。

 

终端打印了train_dataset数据集中的第1个元素,使用索引,即可检索内容。其中,数据集中每一项数据的格式如下:

(img_tensor,label_index),img_tensor是读取到的图片(tensor格式),label_index是图片对应的模型标签(模型中用到的标签,与现实中定义的标签不同,后续会讲)。

另外,终端迭代输出了每一批次数据的形状,数据加载器中设置的batch_size = 256 ,而每一张图像形状为(1,28,28)。

解释下,前面提到的模型标签与现实中真实的标签。可在debug模式下,调试上面代码,打开定义的数据集train_dataset,其中:

classes:该列表变量中,存储的是真实的标签。

class_to_index:影射了真实标签与模型标签的关系,可以看到模型标签是以阿拉伯数字命名,从0开始依次递增+1。

总结:训练时喂入的分类标签,是以阿拉伯数字,从0开始依次递增+1,这样的命名规则的。所以,在模型训练和推理阶段,模型输出的标签依然是阿拉伯数字,这时候定义的class_to_index就有作用了,将模型推理出的阿拉伯数字标签转化为真正的类名。 

3. torchvision.datasets.ImageFolder

torchvision.datasets.ImageFolder 用于从文件夹中加载图像数据集,指定根目录下的每一个子文件夹表示一个类别。该方法通常用于图像分类任务,并且可以很方便地使用Dataloader来管理批量数据。

文件夹的目录结构如下,root表示根目录,class_0和class_1是以类名命名的文件夹,里面分别包含属于该类的图像。

root/

        class_0/

                images1.jpg

                images2.jpg

                ....

        class_1/

                images1.jpg

                images2:jpg

                ....

        ....

我这里的根目录 root,是“ ../ mnist / train",见下图。共有10个文件夹,也就是10类,其中第10类,类名为 ”九“,是我特意修改的,同样也是为了验证真实标签与模型标签之间的关系。

这是第一类 0 文件夹下的数据,均为手写数字0 的图片。 

接下里可直接使用代码加载遍历该数据集。 

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义数据预处理操作
transform = transforms.Compose([
    transforms.ToTensor(),          # 将图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

# 创建ImageFolder数据集,根目录用了绝对路径
dataset = datasets.ImageFolder(root='F:\Amode\datasets\mnist\train', transform=transform)

# 打印数据集中第一项
print(dataset[0])

# 创建DataLoader数据加载器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

#按照常例,迭代遍历数据
for images,labels in data_loader:
    print("images:",images.shape)
    print("lables",labels)

执行代码,终端打印信息,首先还是数据集中的第一项,内容格式仍然是:

(img_tensor,label_index)

同样,更简便的方式,大家用debug模式调试代码,或者直接以 dataset.class_to_index方式输出该变量。

个人觉得,对于分类数据集,这种加载方式是非常容易和轻松的。前提是需要将数据集整理成上面固定的结构 。

4. torchvision.datasets.DatasetFolder

torchvision.datasets.DatasetFolder 是一个比 ImageFolder 更灵活的类,而ImageFolder继承的父类就是它,它允许你自定义加载数据的方式,自定义数据集的结构。

因为比较灵活百变,更难理解和掌握。接下来先了解下该类的详解说明。

下面这部分内容是初始化参数。

  • root :数据集的根目录。
  • loader: 可自定义读取数据样本的方法,该方法传入参数是样本的路径。
  • extension: 扩展名,指的是图片的后缀类型,以元组形式入参。
  • is_valid_file :(可调用对象),获取文件路径并核实文件是否有效,和extension必须有一项有效才行。
  • allow_empty: True 允许空文件被认为是一个类,False反之。

既然ImageFolder的父类就是它,可以先用该类实现ImageFolder中要求的文件夹目录结构(结构在第3部门有说明)。以下代码和ImagesFolde的实现效果一致,数据集目录在第3小节。

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image

#自定义的图像读取方式
def custom_load(path):
    return Image.open(path).convert("RGB")

# 定义数据预处理操作
transform = transforms.Compose([
    transforms.ToTensor(),          # 将图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

# 创建ImageFolder数据集,根目录用了绝对路径
dataset = datasets.DatasetFolder(
     root=r'F:\Amode\datasets\mnist\train',
     loader= custom_load,
     transform=transform,
     extensions=("jpg","png")
)

# 打印数据集中第一项
print(dataset[0])

# 创建DataLoader数据加载器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

#按照常例,迭代遍历数据
for images,labels in data_loader:
    print("images:",images.shape)
    print("lables",labels)

假设,换种数据集的目录结构呢,这里举例一种比较常见的结构,如下图所示。

所有图片都在同一目录下,且图片文件名称以 label_name的格式命名,即标签在文件名中体现。

接下来是实现的代码,新定义了一个类,继承DatasetsFolder类,重新定义了父类中的find_class,make_dataset函数。想具体了解这两个函数的,可点进父类源码中去看。 

find_class:输入根目录root;输出classes(列表),所有的真实标签(str),输出class_to_idx(字典),键为真实标签,值为模型标签(阿拉伯数字)。

make_dataset:输入是父类初始化的参数;输出样本列表,格式为[(file_path,class_index),........]每一个样本使用元组表示,元组中存储的是图像的路径,模型标签。

import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image

#自定义的图像加载方式
def custom_load(path):
    return Image.open(path).convert("RGB")

# 定义数据预处理操作
transform = transforms.Compose([
    transforms.ToTensor(),          # 将图像转换为张量
    transforms.Resize((224, 224)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

class custom_DatasetFolder(datasets.DatasetFolder):
    #重写find_classes函数
    def find_classes(self, directory):
        """
        传参:根目录;
        输出:classes = [] ,classes_to_idx = {class:index}
        """
        lables = set()
        lables_to_indexs = {}
        #获取目录下文件列表
        file_list = os.listdir(self.root)
        #遍历文件列表
        for f in file_list:
            #从文件名中分离出标签
            lable = f.split('_')[0]
            #添加到集合中,集合不允许重复元素
            lables.add(str(lable))
        #生成真实标签label与类别索引class的映射字典
        for i,l in enumerate(list(lables)):
            lables_to_indexs[l] = int(i)
        return list(lables),lables_to_indexs

    def make_dataset(self,directory,class_to_idx,extensions,is_valid_file,allow_empty,):
        """
            传参;
            输出:sample[(path,class),......]
        """

        #获取目录下的文件列表
        file = os.listdir(directory)
        samp = []
        #遍历文件
        for f in file:
            #分离出标签和文件后缀
            lab = f.split('_')[0]
            sufix = f.split('.')[-1]
            #文件后缀满足扩展要求
            if sufix in extensions:
                #根据标签找到类别class
                cls = class_to_idx[lab]
                #文件完整路径
                file_path = os.path.join(directory,f)
                #每个样本以(path,class)格式添加到列表中
                samp.append((str(file_path),cls))
        return samp



# 创建ImageFolder数据集,根目录用了绝对路径
dataset = custom_DatasetFolder(
     root=r'F:\Amode\datasets\image_data',
     loader= custom_load,
     transform=transform,
     extensions=("jpg","png")
)

# 打印数据集中第一项
print(dataset[0])

# 创建DataLoader数据加载器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

#按照常例,迭代遍历数据
for images,labels in data_loader:
    print("images:",images.shape)
    print("lables",labels)

任意结构的数据集,都可以使用基类DatasetFolder实现,主要还是通过覆盖上面两个函数,实现获取标签类别变量,样本的路径和类别,以及自定义加载图片的格式。 

5. torch.utils.data.Datasets

继上面内容,这是唯一一个使用torch,定义数据集的方式。

翻译一下上面的内容:

该类是一个抽象类,所有表示从键到数据样本映射的数据集都应继承此类。所有子类应重写 __getitem__ 方法,以支持根据给定的键获取数据样本。子类还可以选择性地重写 __len__ 方法,这通常会返回数据集的大小,torch.utils.data.Sampler 实现和 torch.utils.data.DataLoader 的默认选项都期望这个方法的存在。子类还可以选择性地实现 __getitems__ 方法,以加速批量样本的加载。该方法接受一个样本索引的列表,并返回这些样本的列表。

那什么叫抽象类呢?

抽象类是一种不能直接实例化的类,主要用于定义方法的基本结构和要求,其作为父类呢,通常让子类去继承它,并且在子类中必须实现这个抽象类中定义的方法,也就是具体的实现交给子类。

本节中用到的基类torch.utils.data.Datasets,需要实现以下三种方法。

  • __init__: 初始化数据集对象,通常在这里加载和处理数据。
  • __len__: 返回数据集的大小(样本数量)。
  • __getitem__: 根据给定的索引返回数据集中的样本和标签,这项是必须的。

这部分的演示代码,使用的是还是上一小节中的数据集 ,数据集格式和实现代码如下。

rom torch.utils.data import Dataset
from PIL import Image
import os


class CustomDataset(Dataset):
    def __init__(self, image_folder, transform=None):
        """
        Args:
            image_folder : 图像所在文件夹的路径
            transform : 应用于样本的转换操作
        """
        self.image_folder = image_folder
        self.transform = transform
        self.class_to_idx = {}
        self.image_files = [f for f in os.listdir(image_folder) if f.endswith('.jpg')]
        self.__class_to_idx()

    def __len__(self):
        """返回数据集中的样本数量"""
        return len(self.image_files)

    def __class_to_idx(self):
        labels = set()
        for file in os.listdir(self.image_folder):
            if file.endswith('.jpg'):
                label = file.split('_')[0]
                labels.add(str(label))
        for i,l in enumerate(labels):
            self.class_to_idx[l] = int(i)


    def __getitem__(self, idx):
        """
        Args:
            idx (int): 索引
        Returns:
            dict: 包含图像和标签的字典
        """
        img_name = os.path.join(self.image_folder, self.image_files[idx])
        image = Image.open(img_name).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # 标签从文件名中提取
        lab_name = self.image_files[idx].split('_')[0]
        label = self.class_to_idx[lab_name]
        return image, label


from torch.utils.data import DataLoader
from torchvision import transforms

# 定义转换操作
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# 实例化自定义数据集
dataset = CustomDataset(image_folder='F:\Amode\datasets\image_data', transform=transform)

# 创建 DataLoader
data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

print(dataset[0])

# 使用 DataLoader 遍历数据
for images, labels in data_loader:
    # 在这里进行训练或测试操作
    print(images.size(), labels)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值