背景
Pytorch深度学习库的DataLoader函数提供使用多进程加载数据的参数接口num_workers,方便用户可以自定义地设置加载数据的进程个数。在此之前,笔者一直认为 每次从DataLoader中获取一个batch的数据,它就会启动num_workers个进程并行加载数据。但在Pytorch1.7版本以后,DataLoader初始化函数多了一些参数,其中最特别的是prefetch_factor参数。该参数在官方文档中的释义为Number of sample loaded in advance by each worker. 2 means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2),即每次DataLoader会提前加载prefetch_factor * num_workers个样本数据到内存之中。
让人疑惑不解的是,为什么预读取的样本数是以进程数为单位,而不是以batch的大小作为单位?如果batch的大小远大于进程数的prefetch_factor倍,那岂不是当读取第一个batch时,训练模型同样需要等到DataLoader再将batch中剩余的数据加载到内存之中。这样的话,DataLoader提供的预读取数据功能又有何意义?或者,是我们理解错了进程与batch数据之间的关系,是每一个进程读取一个batch的数据。为了弄清其中的缘由,笔者本人硬着头皮将Pytorch的DataLoader源码进行了一番解析。
DataLoader源码解析
为了避免长篇论述,我在这里只节选展示了部分关键的代码。读者可以参照本文自行阅读Pytorch的官方,以便更深度地理解DataLoader的工作原理。
- DataLoader在初始化函数中,依照num_workers参数实例化相应个数的加载进程 _utils.worker._worker_loop,并在结束前调用self._reset(loader, first_iter=True) 预加载样本。主进程和子进程之间通过 index_queue 和 data_queue (worker_result_queue) 进行通信。
- _reset函数根据预定义加载样本数量prefetch_factor * num_workers将样本索引发放给数据读取进程workers。
- _try_put_index函数先是通过采样器获取一个批次的样本索引(笔者是通过修改源代码,打印出index变量,才知道其每次获取的是一个批次样本的索引,而不是单个样本的索引),然后找到当前未在工作中的worker,将样本索引放入到其对应的index_queue通信队列之中。
- 当index_queue队列中存在索引,_worker_loop进程(函数)便会开启一个处理数据流,并将当前批次的样本索引传给该流。
- 在接收到一个批次的样本索引后,处理数据流就会根据索引逐一地从数据集获取对应的样本。
综合4和5可知,每个worker进程当从索引队列index_queue中获取一个批次的样本索引后,便会启动加载流获取整个批次的数据,然后将其放入到数据队列之中data_queue。也就是说,起码在Pytorch1.7版本之后,DataLoader中每个进程读取的是一个batch数据。所以prefetch_factor * num_workers可以理解为自定义缓冲池的大小,每当有一个batch的数据从缓冲池中被拿走后,DataLoader就会将一批新的样本索引发送给未工作的worker供其读取数据,以便维持缓冲池中样本的数量。由于Pytorch官方的DataLoader中已经提供了预先加载数据的功能,所以我们今后也无需再手动地编程数据预先读取代码。
上述worker加载batch数据的过程也适用于不适用预加载技术的DataLoader。