PyTorch数据操作:源码级详解Dataset 与DataLoader

部署运行你感兴趣的模型镜像

大家好,今天我会更新Pytorch中数据集的操作方法以及常用的优化方法,以便于明天进行完整的模型训练方案,这一篇我们先讲解dataset与dataloader。

第一章:前言 - 为什么我们需要Dataset 和 DataLoader?

大家好,今天我们来探讨 PyTorch 中处理数据的两个核心工具:Dataset 和 DataLoader。在开始编写代码之前,我们先来思考一个问题:在训练一个深度学习模型时,我们在数据处理上会遇到哪些问题?

数据量庞大

    像 ImageNet 这样的数据集有上百万张图片,总大小超过 100GB。我们不可能一次性把所有数据都读入内存,这会导致内存溢出。

批处理操作麻烦(Batching)

    我们通常使用小批量梯度下降(Mini-batch      Gradient Descent)来训练模型,这意味着我们需要将数据集分成一个个的小批次(batch)。

需要数据打乱(Shuffling)

    为了让模型具有更好的泛化能力,避免训练过拟合,我们希望在每个训练周期(epoch)开始时,都将数据重新打乱顺序。

需要并行处理(Parallelism)

    数据加载和预处理(如图像的缩放、裁剪、归一化)可能会成为训练过程中的瓶颈。如果能利用多核 CPU 并行处理这些任务,就能大大加快训练速度。

面对这些挑战,如果我们手动去实现数据加载、分批、打乱和并行,代码会变得非常复杂且容易出错。

PyTorch 解决了这个问题:它将数据处理的流程拆分成了两个核心部分:

torch.utils.data.Dataset

    负责如何找到并处理单个数据。它告诉程序第 i 个样本是什么,在哪里,如何获取它。

torch.utils.data.DataLoader

    负责如何将数据组合成批次。它从 Dataset 中取出数据,然后自动完成分批、打乱、并行加载等一系列工作。

一个比喻:

假设我们把数据处理当成建房子

Dataset 

    就像一个记录了很多袋水泥位置与基本信息(比如大小)的清单

DataLoader 

    就像你手下的工人。你告诉他:“我需要每批10份水泥(batch_size=10),给我前随机给我(shuffle=True),不需要按大小排列,并且请派 5个帮手一起准备(num_workers=5)。” 你的工人就会根据清单帮你准备准备所需要的所有水泥

接下来,我们将学习如何使用和创建它们。

第二章:核心概念 – Dataset源码级详解

    Dataset 是一个抽象类,我们通常需要继承它来创建自己的数据集类。要创建一个自定义的 Dataset,实现以下两个方法:

__len__(self):

作用

        返回数据集中样本的总数

何时使用:

        DataLoader 需要知道总共有多少数据,以便确定如何划分批次和何时结束一个 epoch。

__getitem__(self, index):

作用:

    根据给定的索引 index,返回一个数据样本。

何时使用

    DataLoader 在构建一个 batch 时,会根据需要(顺序或随机)传入索引 index,调用这个方法来获取单个数据。大部分数据预处理(如图像变换)都在这里完成。

        

实战:创建一个自定义图像数据集

假设有如下的文件夹结构,用于一个猫狗分类任务:

/data/

├── cats/

│   ├── cat.0.jpg

│   ├── cat.1.jpg

│   └── ...

└── dogs/

    ├── dog.0.jpg

    ├── dog.1.jpg

    └── ...

我们可以编写一个源码级别的CustomImageDataset类来加载这些数据。

importos

fromPIL importImage

fromtorch.utils.data importDataset

importtorchvision.transforms as transforms

 

classCustomImageDataset(Dataset):

    def__init__(self, root_dir, transform=None):

        """

       

            root_dir (string): 包含所有图像的根目录。

            transform (callable, optional): 应用于样本的可选变换。

        """

        self.root_dir =root_dir

        self.transform =transform

        self.samples =[]  # 用来存储 (图片路径, 标签)

 

        # 遍历根目录,为猫狗图片打上标签(0猫,1狗)

        forlabel, class_name inenumerate(['cats', 'dogs']):

            class_dir =os.path.join(self.root_dir, class_name)

            forfile_name inos.listdir(class_dir):

                iffile_name.endswith(('.jpg', '.png', '.jpeg')):

                    image_path =os.path.join(class_dir, file_name)

                    self.samples.append((image_path,
  label))

 

    def__len__(self):

        """返回数据集中的样本总数。"""

        returnlen(self.samples)

 

    def__getitem__(self, idx):

        """

        根据索引 idx 获取一个样本。

    

        """

        # 1. 获取图片路径和标签

        image_path, label =self.samples[idx]

 

        # 2. 读取图片

        image =Image.open(image_path).convert('RGB')

 

        # 3. 应用变换 (transform)

        

        ifself.transform:

            image =self.transform(image)

 

        returnimage, label

 

# --- 使用示例 ---

# 定义一些图像变换

# ToTensor() 会将 PIL Image 或 numpy.ndarray 从 (H, W, C) 格式转为 (C, H, W) 格式的 Tensor,

# 并且将像素值从 [0, 255] 缩放到 [0.0, 1.0]

data_transform =transforms.Compose([

    transforms.Resize((128, 128)), # 缩放到 128x128

    transforms.ToTensor(),        
# 转换为 Tensor

    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化

])

 

# 实例化 Dataset

# 假设数据在 './data' 目录下

cat_dog_dataset =CustomImageDataset(root_dir='./data', transform=data_transform)

 

# 验证

print(f"数据总数: {len(cat_dog_dataset)}")

first_image, first_label =cat_dog_dataset[0]

print(f"第一张图Tensor 形状: {first_image.shape}")

print(f"第一张图标签: {first_label}")

在这个例子中:

  • __init__ :负责扫描文件目录,建立一个包含所有图片路径和对应标签的列表      self.samples。
  • __len__ :  直接返回这个列表的长度。
  • __getitem__ : 负责根据索引加载单张图片,并使用      torchvision.transforms 对其进行预处理,最后返回处理好的图片张量(Tensor)和标签。

第三章:核心概念 - DataLoader 详解

DataLoader 是一个迭代器,它在 Dataset 的基础上构建,为我们提供批量、打乱和并行加载数据的能力。

DataLoader 的一些关键参数:

dataset: 我们刚刚创建的 Dataset 实例。

batch_size (int, optional): 每个批次加载的样本数(默认值1)。

shuffle (bool, optional): 如果为 True,则在每个 epoch 开始前都会打乱数据顺序(默认值:False)。在训练时通常设为 True,在验证/测试时设为False。

num_workers (int, optional): 用于数据加载的子进程数。0 表示数据将在主进程中加载(默认值0)。设置为大于 0 的值可以显著加速数据加载,尤其是在数据预处理复杂时。

collate_fn (callable, optional): 一个函数可以用于将 Dataset 中取出的样本列表整理成一个批次。通常情况下,默认的 collate_fn 就能很好地工作,但处理可变长度的序列(如 NLP 任务)时,需要自定义它。

实战:使用 DataLoader 打包数据集

现在,我们把上一节创建的 cat_dog_dataset 喂给 DataLoader。

fromtorch.utils.data importDataLoader

# 假设 cat_dog_dataset 已创建好

# 创建 DataLoader

# 批大小64, 打乱数据, 使用 4 个子进程加载数据

train_loader =DataLoader(

    dataset=cat_dog_dataset,

    batch_size=64,

    shuffle=True,

    num_workers=4

)

注意:在 Windows 或 macOS 上使用 num_workers > 0 时,需要将 DataLoader 的创建和使用代码放在 if __name__ == '__main__': 块中,以避免多进程相关的问题。

第四章:总结

数据处理

    Dataset 关注单个样本的获取与处理;DataLoader 关注样本批次的组合与迭代策略。。

性能瓶颈

    如果 GPU 利用率很低,很可能是数据加载跟不上模型训练的速度。这时,增加num_workers 的值是首要的优化手段。通常可以设置为 CPU 核心数的一半或全部。

内置数据集

    PyTorch 的 torchvision.datasets、torchaudio.datasets 和      torchtext.datasets 已经为我们准备好了许多常用数据集(如 MNIST,      CIFAR10, ImageNet 等),可以直接使用,无需自己编写 Dataset 类。

shuffle=True 用于训练,shuffle=False 用于评估

    训练时需要随机性来提高泛化能力,而评估时需要一个固定的顺序来保证结果的可复现性。

    通过掌握 Dataset 和DataLoader,就掌握了 PyTorch 中数据处理的核心方法

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值