pytorch之数据集构造

本文深入探讨PyTorch中自定义数据集的方法,解析DataLoader的工作原理及如何构建适合特定任务的数据加载器。文章通过VOC2012数据集分割的实例,演示如何创建和使用DataSet与DataLoader,特别关注于collate_fn函数的定制,以满足不同数据格式的需求。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这些天看的东西,真的是比较多,相比以前来说,对我的学习方式起到颠覆性作用。我目前觉得,我们学到的东西,更多是孤立的,因此,在吸收一定知识后,需要在脑子里形成知识体系。需要把自己以前学到的东西进行整理,形成一个体系,这篇文章讲解的是,深度学习中pytorch数据集的构造!!!

pytorch中有两个自定义管理数据集的类,

  1. torch.utils.data.DataSet
  2. torvchvision.datasets.ImageFolder

这里主要讲解的第一种。

DataSet源码

class Dataset(object):
    """An abstract class representing a Dataset.

    All other datasets should subclass it. All subclasses should override
    ``__len__``, that provides the size of the dataset, and ``__getitem__``,
    supporting integer indexing in range from 0 to len(self) exclusive.
    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])

我们设计自己数据集类的时候, 只需要重写 __getitem__、__len__两个函数,分别的功能是, 通过切片返回具样例返回样本个数
以下是voc2012数据集分割的例子:

import os

import numpy as np
from PIL import Image
from torch.utils import data


def read_images(root, train):
    txt_fname = os.path.join(root, 'ImageSets/Segmentation/') + ('train.txt' if train else 'val.txt')
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    data = [os.path.join(root, 'JPEGImages', i + '.jpg') for i in images]
    label = [os.path.join(root, 'SegmentationClass', i + '.png') for i in images]
    return data, label


class VocSegDataset(data.Dataset):

    def __init__(self, cfg, train, transforms=None):
        self.cfg = cfg
        self.train = train
        self.transforms = transforms
        self.data_list, self.label_list = read_images(self.cfg.DATASETS.ROOT, train)

    def __getitem__(self, item):
        img = self.data_list[item]
        label = self.label_list[item]
        img = Image.open(img)
        # load label
        label = Image.open(label)
        img, label = self.transforms(img, label)
        return img, label

    def __len__(self):
        return len(self.data_list)

通过上面的操作,我们构建自己数据集类,接下来,构建一个 Dataloader类,这个作用是训练过程中,返回 batch个样例。

Dataloder源码

由于源码过于臃肿了,这里知识摘出对应的构造函数:

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):

构造函数中,每个参数的意思就不一一介绍了,只着重的讲解下,可调用函数collate_fn。我们首先看一个构建Dataloader的实例:

def build_dataset(cfg, transforms, is_train=True):
    datasets = VocSegDataset(cfg, is_train, transforms)
    return datasets


def make_data_loader(cfg, is_train=True):
    if is_train:
        batch_size = cfg.SOLVER.IMS_PER_BATCH
        shuffle = True
    else:
        batch_size = cfg.TEST.IMS_PER_BATCH
        shuffle = False

    transforms = build_transforms(cfg, is_train)
    datasets = build_dataset(cfg, transforms, is_train)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    data_loader = data.DataLoader(
        datasets, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True
    )

    return data_loader

上面第一个函数 build_dataset返回数据集实例,第二个函数返回Dataloader,关于Dataloader,我们需要注意的是,有时我们需要根据Dataset中的__getitem__修改collate_fn
我们来看下源码:

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
            return batch

我们在源码中发现,collate_fn的输入是一个list,里面的每个元素是__getitem__的输出,由此,我们估计,default_collate的作用是将这个list,**变换格式为[batch,C,H,W]**的tensor,我们在来看下源码:

	if.......
	.........
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))

由于源码均是对类型的判断,因此,这里我们知识摘出,与voc2012分割相关的部分,这个语句的意思是, 对[(img1, label1), (img2, label2)],首先返回[img1,img2],[lable1,label2],在继续返回两个tensor,一个是img,[batch,C,H,W],一个是label:[batch,C,H,W]。
所以,通过上面分析,如果,我们__getitem__不符合collat_fn不符合默认函数的判断时,需要修改该函数。
好了,先到这,接下来…慢慢聊程序,需要学的太多了

### 创建和使用PyTorch数据集 #### 数据准备 为了有效地利用PyTorch进行机器学习项目的数据准备工作,可以采用简单的方法来创建自定义数据集。具体来说,可以通过编写Python脚本来加载图像文件并应用必要的预处理操作[^1]。 ```python import glob import torch from torch.utils import data from PIL import Image import numpy as np from torchvision import transforms import matplotlib.pyplot as plt ``` 这段代码展示了导入所需库的过程,这些库对于后续的操作至关重要。通过`glob`模块获取文件列表;借助PIL库读取图片;运用Numpy执行数组运算;以及利用Matplotlib展示图像效果。同时,还引入了来自`torchvision.transforms`的各种变换工具用于增强数据多样性[^3]。 #### 构建自定义数据集类 当标准的内置数据集无法满足特定需求时,则需设计一个继承于`data.Dataset`的新类——这里命名为`MyDataset`。此类应重写两个核心方法:`__len__()`返回整个集合大小,而`__getitem__(self, idx)`负责按索引访问单条记录[^2]。 ```python class MyDataset(data.Dataset): def __init__(self, image_paths, transform=None): self.image_paths = image_paths self.transform = transform def __len__(self): return len(self.image_paths) def __getitem__(self, index): path = self.image_paths[index] img = Image.open(path).convert('RGB') if self.transform is not None: img = self.transform(img) return img ``` 上述实现了最基础版本的自定义数据集构造器,在初始化阶段接收一系列图像路径作为输入参数,并允许传入额外的转换规则应用于每张载入后的照片上[^4]。 #### 整合与迭代数据 完成以上步骤之后,便能够轻松地把实例化的对象传递给DataLoader来进行批量采样及随机打乱顺序等高级功能[^5]: ```python dataset = MyDataset(image_paths=glob.glob('./images/*.jpg'), transform=transforms.ToTensor()) dataloader = DataLoader(dataset, batch_size=8, shuffle=True) for images in dataloader: # 进行训练或其他处理... pass ``` 此部分演示了如何将之前建立好的`MyDataset`实例化为实际可用的对象,并设置每次取出多少样本组成一批次(`batch_size`)还有是否开启混洗模式(shuffle)以提高泛化能力。最后进入循环结构逐批次取得待用资料供下游任务调用。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值