大家好,今天我会更新Pytorch中数据集的操作方法以及常用的优化方法,以便于明天进行完整的模型训练方案,这一篇我们先讲解dataset与dataloader。
第一章:前言 - 为什么我们需要Dataset 和 DataLoader?
大家好,今天我们来探讨 PyTorch 中处理数据的两个核心工具:Dataset 和 DataLoader。在开始编写代码之前,我们先来思考一个问题:在训练一个深度学习模型时,我们在数据处理上会遇到哪些问题?
数据量庞大
像 ImageNet 这样的数据集有上百万张图片,总大小超过 100GB。我们不可能一次性把所有数据都读入内存,这会导致内存溢出。
批处理操作麻烦(Batching)
我们通常使用小批量梯度下降(Mini-batch Gradient Descent)来训练模型,这意味着我们需要将数据集分成一个个的小批次(batch)。
需要数据打乱(Shuffling)
为了让模型具有更好的泛化能力,避免训练过拟合,我们希望在每个训练周期(epoch)开始时,都将数据重新打乱顺序。
需要并行处理(Parallelism)
数据加载和预处理(如图像的缩放、裁剪、归一化)可能会成为训练过程中的瓶颈。如果能利用多核 CPU 并行处理这些任务,就能大大加快训练速度。
面对这些挑战,如果我们手动去实现数据加载、分批、打乱和并行,代码会变得非常复杂且容易出错。
PyTorch 解决了这个问题:它将数据处理的流程拆分成了两个核心部分:
torch.utils.data.Dataset
负责如何找到并处理单个数据。它告诉程序第 i 个样本是什么,在哪里,如何获取它。
torch.utils.data.DataLoader
负责如何将数据组合成批次。它从 Dataset 中取出数据,然后自动完成分批、打乱、并行加载等一系列工作。
一个比喻:
假设我们把数据处理当成建房子
Dataset
就像一个记录了很多袋水泥位置与基本信息(比如大小)的清单
DataLoader
就像你手下的工人。你告诉他:“我需要每批10份水泥(batch_size=10),给我前随机给我(shuffle=True),不需要按大小排列,并且请派 5个帮手一起准备(num_workers=5)。” 你的工人就会根据清单帮你准备准备所需要的所有水泥
接下来,我们将学习如何使用和创建它们。
第二章:核心概念 – Dataset源码级详解
Dataset 是一个抽象类,我们通常需要继承它来创建自己的数据集类。要创建一个自定义的 Dataset,须实现以下两个方法:
__len__(self):
作用
返回数据集中样本的总数
何时使用:
DataLoader 需要知道总共有多少数据,以便确定如何划分批次和何时结束一个 epoch。
__getitem__(self, index):
作用:
根据给定的索引 index,返回一个数据样本。
何时使用:
DataLoader 在构建一个 batch 时,会根据需要(顺序或随机)传入索引 index,调用这个方法来获取单个数据。大部分数据预处理(如图像变换)都在这里完成。
实战:创建一个自定义图像数据集
假设有如下的文件夹结构,用于一个猫狗分类任务:
/data/
├── cats/
│ ├── cat.0.jpg
│ ├── cat.1.jpg
│ └── ...
└── dogs/
├── dog.0.jpg
├── dog.1.jpg
└── ...
我们可以编写一个源码级别的CustomImageDataset类来加载这些数据。
importos
fromPIL importImage
fromtorch.utils.data importDataset
importtorchvision.transforms as transforms
classCustomImageDataset(Dataset):
def__init__(self, root_dir, transform=None):
"""
root_dir (string): 包含所有图像的根目录。
transform (callable, optional): 应用于样本的可选变换。
"""
self.root_dir =root_dir
self.transform =transform
self.samples =[] # 用来存储 (图片路径, 标签)
# 遍历根目录,为猫狗图片打上标签(0猫,1狗)
forlabel, class_name inenumerate(['cats', 'dogs']):
class_dir =os.path.join(self.root_dir, class_name)
forfile_name inos.listdir(class_dir):
iffile_name.endswith(('.jpg', '.png', '.jpeg')):
image_path =os.path.join(class_dir, file_name)
self.samples.append((image_path,
label))
def__len__(self):
"""返回数据集中的样本总数。"""
returnlen(self.samples)
def__getitem__(self, idx):
"""
根据索引 idx 获取一个样本。
"""
# 1. 获取图片路径和标签
image_path, label =self.samples[idx]
# 2. 读取图片
image =Image.open(image_path).convert('RGB')
# 3. 应用变换 (transform)
ifself.transform:
image =self.transform(image)
returnimage, label
# --- 使用示例 ---
# 定义一些图像变换
# ToTensor() 会将 PIL Image 或 numpy.ndarray 从 (H, W, C) 格式转为 (C, H, W) 格式的 Tensor,
# 并且将像素值从 [0, 255] 缩放到 [0.0, 1.0]
data_transform =transforms.Compose([
transforms.Resize((128, 128)), # 缩放到 128x128
transforms.ToTensor(),
# 转换为 Tensor
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 归一化
])
# 实例化 Dataset
# 假设数据在 './data' 目录下
cat_dog_dataset =CustomImageDataset(root_dir='./data', transform=data_transform)
# 验证
print(f"数据总数: {len(cat_dog_dataset)}")
first_image, first_label =cat_dog_dataset[0]
print(f"第一张图Tensor 形状: {first_image.shape}")
print(f"第一张图标签: {first_label}")
在这个例子中:
- __init__ :负责扫描文件目录,建立一个包含所有图片路径和对应标签的列表 self.samples。
- __len__ : 直接返回这个列表的长度。
- __getitem__ : 负责根据索引加载单张图片,并使用 torchvision.transforms 对其进行预处理,最后返回处理好的图片张量(Tensor)和标签。
第三章:核心概念 - DataLoader 详解
DataLoader 是一个迭代器,它在 Dataset 的基础上构建,为我们提供批量、打乱和并行加载数据的能力。
DataLoader 的一些关键参数:
dataset: 我们刚刚创建的 Dataset 实例。
batch_size (int, optional): 每个批次加载的样本数(默认值1)。
shuffle (bool, optional): 如果为 True,则在每个 epoch 开始前都会打乱数据顺序(默认值:False)。在训练时通常设为 True,在验证/测试时设为False。
num_workers (int, optional): 用于数据加载的子进程数。0 表示数据将在主进程中加载(默认值0)。设置为大于 0 的值可以显著加速数据加载,尤其是在数据预处理复杂时。
collate_fn (callable, optional): 一个函数可以用于将 Dataset 中取出的样本列表整理成一个批次。通常情况下,默认的 collate_fn 就能很好地工作,但处理可变长度的序列(如 NLP 任务)时,需要自定义它。
实战:使用 DataLoader 打包数据集
现在,我们把上一节创建的 cat_dog_dataset 喂给 DataLoader。
fromtorch.utils.data importDataLoader
# 假设 cat_dog_dataset 已创建好
# 创建 DataLoader
# 批大小64, 打乱数据, 使用 4 个子进程加载数据
train_loader =DataLoader(
dataset=cat_dog_dataset,
batch_size=64,
shuffle=True,
num_workers=4
)
注意:在 Windows 或 macOS 上使用 num_workers > 0 时,需要将 DataLoader 的创建和使用代码放在 if __name__ == '__main__': 块中,以避免多进程相关的问题。
第四章:总结
数据处理
Dataset 关注单个样本的获取与处理;DataLoader 关注样本批次的组合与迭代策略。。
性能瓶颈
如果 GPU 利用率很低,很可能是数据加载跟不上模型训练的速度。这时,增加num_workers 的值是首要的优化手段。通常可以设置为 CPU 核心数的一半或全部。
内置数据集
PyTorch 的 torchvision.datasets、torchaudio.datasets 和 torchtext.datasets 已经为我们准备好了许多常用数据集(如 MNIST, CIFAR10, ImageNet 等),可以直接使用,无需自己编写 Dataset 类。
shuffle=True 用于训练,shuffle=False 用于评估
训练时需要随机性来提高泛化能力,而评估时需要一个固定的顺序来保证结果的可复现性。
通过掌握 Dataset 和DataLoader,就掌握了 PyTorch 中数据处理的核心方法
739

被折叠的 条评论
为什么被折叠?



