示例:
from torch.utils.data import DataLoader
train_loader = DataLoader(dataset=train_data, batch_size=batch, shuffle=True, num_worker=4)
valid_loader = DataLoader(dataset=valid_data, batch_size=batch, num_worker=4)
1、num_workers是加载数据(batch)的线程数目
num_workers通过影响数据加载速度,从而影响训练速度。每轮dataloader加载数据时:dataloader一次性创建num_worker个worker,worker就是普通的工作进程,并用batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM。然后,dataloader从RAM中找本轮迭代要用的batch,如果找到了,就使用。如果没找到,就要num_worker个worker继续加载batch到内存,直到dataloader在RAM中找到目标batch。
num_worker个worker --> RAM