工具学习:Pytorch数据迭代器

1. 问题来源

在学习深度学习时,在训练时为了使用并行计算(小批量)来加速数据加载,需要通过迭代器(Dataloader)来加载数据,并将数据打乱(reshuffle)来避免过拟合:

from torch.utils import data
from d2l import torch as d2l
from torchvision import datasets
from torchvision.transformer import ToTensor
batch_size = 256

# 加载训练数据
mnist_train = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()

def get_dataloader_workers():
    """使用4个进程来读取数据"""
    return 4

# 设计数据迭代器
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, 
                             num_workers=get_dataloader_workers())

# 通过时间测试,发现读取训练数据所需的时间会与一些因素有关
# 查看读取训练数据所需时间
timer = d2l.Timer()
for X, y in train_iter:
    continue

f'{timer.stop():.2f} sec'

例如,将数据批量大小设置为1时,加载训练数据所需的时间是15s,而数据批量大小设置为256时,加载训练数据所需时间是2s。当然,影响数据迭代器性能的因素很多,除了batch_size可能还有其他因素,因此,根据学习资料中作者的引导,让我开始思考如下问题:
数据迭代器的性能非常重要。当前的实现足够快吗?探索各种选择来改进

2. Pytorch中的数据迭代器是什么?

Pytorch中的数据迭代器为DataLoader,DataLoader是一个可迭代对象,为了根据批量大小来随机获取数据,以实现并行加载数据,提升数据训练效率。

from torch.utils.data import DataLoader

# 创建训练数据迭代器,批量大小为64,打乱数据,采用的进程数量为4
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True, num_workers=4)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=4)

根据pytorch中DataLoader类的定义1,来分析DataLoader的结构。

2.1 初始化参数

DataLoader的常用初始化参数如下表所示:

初始化参数参数说明
dataset必需参数,指定要加载的数据集(Dataset子类实例)
batch_size每个批次的样本数量,默认为1
shuffle是否在每个epoch开始时打乱数据顺序,默认为False
sampler自定义采样器,用于控制数据的采样方式。若指定此参数,shuffle会被忽略
num_workers用于数据加载的子进程数量,默认是0(即主进程加载)
collate_fn将多个样本整理成批次的函数,默认将数据转换为张量

2.2 核心组件

为了实现将数据分为不同大小的批次喂给模型,DataLoader的主要核心组件包括以下几个:
(1)采样器与分批逻辑

  • 采样器(Sampler):负责生成数据样本的索引序列,若未指定,则根据shuffle参数创建默认采样器
  • 批次采样器(BatchSampler): 结合sampler和batch_size参数,将采样器生成的索引划分为多个批次

(2)多进程数据加载
通过设置num_workers参数:

  • 通过多进程并行加载数据,避免I/O阻塞,提升数据加载速度。
  • 每个子进程独立从dataset中获取数据,并通过collate_fn整理后返回主进程

(3)数据整理与批次生成
collate_fn的作用:

  • 默认情况下,collate_fn将多个样本(如tensor、numpy array或字典)按维度拼接成批次张量
  • 用户可以自定义collate_fn以处理复杂数据格式(如图像、文本序列等)

(4)迭代器接口

  • DataLoader实现了迭代器协议(__iter__和__next__方法),支持通过for循环遍历批次数据
  • 迭代时,DataLoader通过采样器和批次采样器生成索引,多进程加载数据,最终返回整理后的批次

2.3 关于迭代器的简单实现

《动手学深度学习》一书中给出了关于迭代器的简单实现版本,能很好地理解其工作原理:

def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    # 这些样本是随机读取的,没有特定的顺序
    random.shuffle(indices)  # 打乱索引
    for i in range(0, num_examples, batch_size):  # 根据批量大小和索引选取数据
        batch_indices = torch.tensor(
            indices[i: min(i + batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]

3. 如何提升数据迭代器性能?

根据对数据迭代器的分析,想要提升数据迭代器的性能,最直观的两个方法是改变批量大小和进程数量,我们先来实验一下,为了不受随机环境的影响,将每个参数实验十次,取平均值:

测试方法

### 设计一个函数,将每个参数实验10次,然后返回平均时间
time = []
def experement(batch_size=256, num_workers=4):
    train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, 
                             num_workers=num_workers)
    timer = d2l.Timer()
    for X, y in train_iter:
        continue
    return timer.stop()

for _ in range(10):
    time.append(experement())

print(sum(time)/len(time))

测试结果

测试1:改变批量大小

批量大小进程数量运行时间
6442.31s
12842.22s
25642.12s
51242.53s

测试2:改变进程数量

批量大小进程数量运行时间
25622.29s
25642.12s
25662.60s
25682.98s

好像一开始选择的参数batch_size=256,num_works=4已经是最优参数了…


后续我将继续关注新的方法来优化数据迭代器,也欢迎各位大佬给点建议!


  1. torch.utils.data.DataLoader.py ↩︎

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

会编程的加缪

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值