Pytorch DataLoader中的num_workers (选择最合适的num_workers值)

一、概念

num_workers是Dataloader的概念,默认值是0。是告诉DataLoader实例要使用多少个子进程进行数据加载(和CPU有关,和GPU无关)
如果num_worker设为0,意味着每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了),而是在RAM中找batch,找不到时再加载相应的batch。缺点当然是速度慢。

当num_worker不为0时,每轮到dataloader加载数据时,dataloader一次性创建num_worker个worker,并用batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM。

num_worker设置得大,好处是寻batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮…迭代时已经加载好了。坏处是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是CPU复制的嘛)。num_workers的经验设置值是自己电脑/服务器的CPU核心数,如果CPU很强、RAM也很充足,就可以设置得更大些。

num_worker小了的情况,主进程采集完最后一个worker的batch。此时需要回去采集第一个worker产生的第二个batch。如果该worker此时没有采集完,主线程会卡在这里等。(这种情况出现在,num_works数量少或者batchsize
比较小,显卡很快就计算完了,CPU对GPU供不应求。)

即,num_workers的值和模型训练快慢有关,和训练出的模型的performance无关

Detectron2的num_workers默认是4

二、选择最合适的num_workers值

最合适的num_works值与数据集有关
最好是跑代码之前先用这段script跑一下,选择最合适的num_workers值

from time import time
import multiprocessing as mp
import torch
import torchvision
from torchvision import transforms
 
 
transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
 
trainset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,  #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
    download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
    transform=transform
)
 
print(f"num of CPU: {mp.cpu_count()}")
for num_workers in range(2, mp.cpu_count(), 2):  
    train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
    start = time()
    for epoch in range(1, 3):
        for i, data in enumerate(train_loader, 0):
            pass
    end = time()
    print("Finish with:{} second, num_workers={}".format(end - start, num_workers))

在这里插入图片描述
可以看到,这个服务器24个CPU, 最合适的num_workers值是14

三、可能出现的问题

在这里插入图片描述
linux系统中可以使用多个子进程加载数据,windows系统里是不可以的,可以发现报错时产生在DataLoader文件中的。我们找到自己调用DataLoader的文件中num_workers的设置,设置为0或者采用默认为0的设置。

### 解决 PyTorch DataLoader `num_workers` 引发的 Broken Pipe 错误 #### 设置合适的 `num_workers` 在 Windows 系统上,当 `PyTorch` 的 `DataLoader` 使用多线程 (`num_workers>0`) 加载数据时可能会遇到 `BrokenPipeError: [Errno 32] Broken pipe` 错误[^2]。一种解决方案是减少或设置 `num_workers=0` 来禁用多进程数据加载,但这会影响性能。 另一种方法是在初始化 `DataLoader` 之前设定启动方式为 `'spawn'` 或者 `'forkserver'` 而不是默认的方式: ```python import torch.multiprocessing as mp mp.set_start_method('spawn') ``` 这可以有效防止某些情况下发生的管道破裂错误[^4]。 #### 修改 Python 多处理环境变量 对于 Windows 用户来说,可以通过修改环境变量来解决问题。具体做法是在程序运行前加入以下代码片段: ```python import os os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning' ``` 此操作会忽略一些警告信息并可能间接修复问题[^5]。 #### 更新 PyTorch 版本 有时该问题是由于特定版本中的 bug 导致,在这种情况下更新到最新稳定版可能是最简单有效的解决办法之一[^1]。 #### 配置持久工作者选项 如果希望保持较高的工作效率而不降低 `num_workers` 数量,则可尝试开启 `persistent_workers=True` 参数配置项。需要注意的是这个特性仅适用于 `num_workers > 0` 的情况,并且能够显著提高效率因为子进程不会随着每次 epoch 结束而销毁重建[^3]。 ```python dataloader = torch.utils.data.DataLoader( dataset, batch_size=64, shuffle=True, num_workers=4, persistent_workers=True ) ``` 以上措施应该能帮助缓解甚至彻底消除由 `num_workers` 设定不当所引起的 `BrokenPipeError` 报错现象。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值