DataLoader是每一个worker加载一个batch数据,并不是所有的worker一起加载一个batch数据

文章解析了Pytorch1.7版本后DataLoader的prefetch_factor参数,指出每个进程读取的是一个batch数据,而非单个样本。prefetch_factor相当于自定义缓冲池大小,确保数据预加载以提高性能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

背景

Pytorch深度学习库的DataLoader函数提供使用多进程加载数据的参数接口num_workers,方便用户可以自定义地设置加载数据的进程个数。在此之前,笔者一直认为 每次从DataLoader中获取一个batch的数据,它就会启动num_workers个进程并行加载数据。但在Pytorch1.7版本以后,DataLoader初始化函数多了一些参数,其中最特别的是prefetch_factor参数。该参数在官方文档中的释义为Number of sample loaded in advance by each worker. 2 means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2),即每次DataLoader会提前加载prefetch_factor * num_workers个样本数据到内存之中。
让人疑惑不解的是,为什么预读取的样本数是以进程数为单位,而不是以batch的大小作为单位?如果batch的大小远大于进程数的prefetch_factor倍,那岂不是当读取第一个batch时,训练模型同样需要等到DataLoader再将batch中剩余的数据加载到内存之中。这样的话,DataLoader提供的预读取数据功能又有何意义?或者,是我们理解错了进程与batch数据之间的关系,是每一个进程读取一个batch的数据。为了弄清其中的缘由,笔者本人硬着头皮将Pytorch的DataLoader源码进行了一番解析。

DataLoader源码解析

为了避免长篇论述,我在这里只节选展示了部分关键的代码。读者可以参照本文自行阅读Pytorch的官方,以便更深度地理解DataLoader的工作原理。

  1. DataLoader在初始化函数中,依照num_workers参数实例化相应个数的加载进程 _utils.worker._worker_loop,并在结束前调用self._reset(loader, first_iter=True) 预加载样本。主进程和子进程之间通过 index_queuedata_queue (worker_result_queue) 进行通信。

DataLoader初始化函数

  1. _reset函数根据预定义加载样本数量prefetch_factor * num_workers将样本索引发放给数据读取进程workers。

在这里插入图片描述

  1. _try_put_index函数先是通过采样器获取一个批次的样本索引(笔者是通过修改源代码,打印出index变量,才知道其每次获取的是一个批次样本的索引,而不是单个样本的索引),然后找到当前未在工作中的worker,将样本索引放入到其对应的index_queue通信队列之中。

在这里插入图片描述

  1. index_queue队列中存在索引,_worker_loop进程(函数)便会开启一个处理数据流,并将当前批次的样本索引传给该流。

在这里插入图片描述

  1. 在接收到一个批次的样本索引后,处理数据流就会根据索引逐一地从数据集获取对应的样本。

在这里插入图片描述
综合4和5可知,每个worker进程当从索引队列index_queue中获取一个批次的样本索引后,便会启动加载流获取整个批次的数据,然后将其放入到数据队列之中data_queue。也就是说,起码在Pytorch1.7版本之后,DataLoader中每个进程读取的是一个batch数据。所以prefetch_factor * num_workers可以理解为自定义缓冲池的大小,每当有一个batch的数据从缓冲池中被拿走后,DataLoader就会将一批新的样本索引发送给未工作的worker供其读取数据,以便维持缓冲池中样本的数量。由于Pytorch官方的DataLoader中已经提供了预先加载数据的功能,所以我们今后也无需再手动地编程数据预先读取代码。

上述worker加载batch数据的过程也适用于不适用预加载技术的DataLoader

下面是一个例子,展示如何用PyTorch自己写一个dataloaderdataloader集成自object对象。 1. 首先,需要导入PyTorchDataLoader和Dataset模块: ``` import torch from torch.utils.data import DataLoader, Dataset ``` 2. 接下来,定义一个自定义的Dataset类,继承自PyTorch的Dataset类,实现__len__和__getitem__函数: ``` class CustomDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, index): return self.data[index] ``` 其中,__init__函数用于初始化数据集,__len__函数用于返回数据集的大小,__getitem__函数用于返回指定索引的数据。 3. 然后,定义一个自定义的DataLoader类,继承自PyTorchDataLoader类,实现__init__和__iter__函数: ``` class CustomDataLoader(DataLoader): def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None): super().__init__(dataset, batch_size, shuffle, sampler, batch_sampler, num_workers, collate_fn, pin_memory, drop_last, timeout, worker_init_fn, multiprocessing_context) def __iter__(self): for batch in super().__iter__(): yield self.transform(batch) def transform(self, batch): # 对批次数据进行变换,这里仅作为示例,不做实际变换 return batch ``` 其中,__init__函数用于初始化DataLoader,__iter__函数用于循环获取批次数据在获取前对数据进行变换(这里仅作为示例,不做实际变换)。 4. 最后,调用CustomDataLoader即可获取一个dataloader: ``` data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] batch_size = 2 dataset = CustomDataset(data) dataloader = CustomDataLoader(dataset, batch_size=batch_size, shuffle=True) for batch in dataloader: print(batch) ``` 这样就可以得到一个dataloader了。在本例中,数据集是一个简单的数字列表,每个批次包含两个数字,dataloader会将数据集分成多个批次,每次输出一个批次的数据。自定义的DataLoader类继承自PyTorchDataLoader类,覆盖了__init__和__iter__函数,以实现自定义的功能。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值