DataLoader Source Code Walkthrough
DataLoader
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')
- DataLoader combines a Dataset with a Sampler and provides an iterator over the given Dataset. The Dataset holds the actual data, while the Sampler yields a sequence of indices; together they form an iterator that fetches data by index. Shuffled indices are typically used for training, while sequential indices are usually used for validation or testing;
- num_workers is the number of worker processes used to load data. With 0, the data is loaded in the main process, so loading and training run serially: loading blocks model training. With 1, a separate process loads the data, so loading and training run in parallel: while the main process trains the model, another process loads data, at the cost of some extra memory and CPU.
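A minimal sketch of these two parameters (TensorDataset is used here only as a convenient stand-in for any map-style dataset; the shapes and worker count are arbitrary):

import torch
from torch.utils.data import DataLoader, TensorDataset

# A toy map-style dataset: 100 samples with 10 features each, plus binary labels.
dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))

# num_workers=0: data is loaded in the main process, so loading blocks training.
loader_serial = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=0)

# num_workers=2: two worker processes load batches in parallel with training.
loader_parallel = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)

for features, labels in loader_serial:
    # features: [16, 10], labels: [16]; the last batch holds the remaining 4 samples.
    pass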
Before walking through DataLoader's processing flow, it is worth first looking at a few of its internal classes.
Dataset
There are two kinds of Dataset. A map-style dataset is a collection that supports random access by index: each sample can be fetched directly through the __getitem__ method given its index. An iterable-style dataset can only be accessed sequentially by iteration; it suits data that cannot be loaded into memory or indexed up front, and is commonly used for generated data, large data files read on the fly, or streaming data sources. In most cases a map-style dataset is what you use.
To use a map-style dataset, subclass Dataset and override the __getitem__ and __len__ methods. __getitem__ returns the sample corresponding to a given index, while __len__ serves the Sampler, telling it the upper bound of the index values and how many indices to generate.
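A minimal map-style dataset might look like the sketch below; the SquaresDataset name and its contents are invented for illustration:

import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Map-style dataset returning (x, x**2) pairs."""
    def __init__(self, n: int) -> None:
        self.x = torch.arange(n, dtype=torch.float32)

    def __getitem__(self, index):
        # Called by the fetcher with an index produced by the Sampler.
        return self.x[index], self.x[index] ** 2

    def __len__(self) -> int:
        # Used by the Sampler to know how many indices to generate.
        return len(self.x)

ds = SquaresDataset(5)
print(len(ds))   # 5
print(ds[2])     # (tensor(2.), tensor(4.))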
Sampler
A Sampler's job is to yield a sequence of indices, either sequential or shuffled. Its input is the dataset; its output is indices into the data. For a dataset of size N, SequentialSampler yields 0, 1, 2, ..., N-1, while a random sampler also yields N indices, but because they are shuffled, the index at any given position is not fixed.
In the code, the sampler does not return a list directly; it returns an iterator, and indices are obtained by repeatedly stepping through that iterator.
Depending on whether sampling with replacement is allowed (the replacement argument of the Sampler), the same index value may appear more than once; without replacement, index values are normally not repeated.
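For example (the shuffled order shown in the comment is only illustrative, since it depends on the RNG state):

import torch
from torch.utils.data import SequentialSampler, RandomSampler

data = range(5)
print(list(SequentialSampler(data)))   # [0, 1, 2, 3, 4]
print(list(RandomSampler(data)))       # e.g. [3, 0, 4, 1, 2] -- a permutation of 0..4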
SequentialSampler
class SequentialSampler(Sampler[int]):
r"""Samples elements sequentially, always in the same order.
Args:
data_source (Dataset): dataset to sample from
"""
data_source: Sized
def __init__(self, data_source: Sized) -> None:
self.data_source = data_source
    # Return an iterator over 0..len(data_source)-1 in order.
def __iter__(self) -> Iterator[int]:
return iter(range(len(self.data_source)))
def __len__(self) -> int:
return len(self.data_source)
RandomSampler
class RandomSampler(Sampler[int]):
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify :attr:`num_samples` to draw.
Args:
data_source (Dataset): dataset to sample from
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
num_samples (int): number of samples to draw, default=`len(dataset)`.
generator (Generator): Generator used in sampling.
"""
data_source: Sized
replacement: bool
def __init__(self, data_source: Sized, replacement: bool = False,
num_samples: Optional[int] = None, generator=None) -> None:
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator
if not isinstance(self.replacement, bool):
raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
@property
def num_samples(self) -> int:
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = self.generator
if self.replacement:
            # Generate indices in chunks of 32 so that too many are never materialized at once.
for _ in range(self.num_samples // 32):
                # torch.randint draws indices with replacement.
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
for _ in range(self.num_samples // n):
                # torch.randperm produces a random permutation directly; `yield from` iterates over it and yields each element.
yield from torch.randperm(n, generator=generator).tolist()
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
def __len__(self) -> int:
return self.num_samples
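A quick sketch of the replacement path shown in __iter__ above (the concrete indices depend on the generator seed, so the comment only describes their shape):

import torch
from torch.utils.data import RandomSampler

# With replacement, num_samples may exceed the dataset size and indices can repeat.
g = torch.Generator()
g.manual_seed(0)
sampler = RandomSampler(range(5), replacement=True, num_samples=12, generator=g)
print(list(sampler))  # 12 indices drawn from 0..4, repeats allowed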
BatchSampler
BatchSampler wraps another Sampler and yields mini-batches of indices, i.e. index lists of length batch_size. It splits the sequence produced by the wrapped Sampler into chunks of batch_size and yields them one at a time: a sequence of length 10 with batch_size 3, for example, yields four index lists (when drop_last=False). Its input is a Sampler; its output is one index list per mini-batch.
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
Note: the length of a BatchSampler is not the number of samples but the number of batches: len(sampler) // batch_size when drop_last=True, and the ceiling of len(sampler) / batch_size otherwise.
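For instance:

from torch.utils.data import BatchSampler, SequentialSampler

print(len(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)))  # 4
print(len(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)))   # 3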
class BatchSampler(Sampler[List[int]]):
r"""Wraps another sampler to yield a mini-batch of indices.
Args:
sampler (Sampler or Iterable): Base sampler. Can be any iterable object
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``
Example:
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
"""
def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
# Since collections.abc.Iterable does not check for `__getitem__`, which
# is one way for an object to be an iterable, we don't do an `isinstance`
# check here.
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
if not isinstance(drop_last, bool):
raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}")
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self) -> Iterator[List[int]]:
# Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951
if self.drop_last:
sampler_iter = iter(self.sampler)
while True:
try:
                    # Pull batch_size indices from the sampler at a time and yield them as one batch.
batch = [next(sampler_iter) for _ in range(self.batch_size)]
yield batch
except StopIteration:
break
else:
batch = [0] * self.batch_size
idx_in_batch = 0
            # self.sampler yields N indices in total.
for idx in self.sampler:
batch[idx_in_batch] = idx
idx_in_batch += 1
                # Once a full batch of indices has been collected, yield the index list.
if idx_in_batch == self.batch_size:
yield batch
idx_in_batch = 0
batch = [0] * self.batch_size
if idx_in_batch > 0:
yield batch[:idx_in_batch]
def __len__(self) -> int:
# Can only be called if self.sampler has __len__ implemented
# We cannot enforce this condition, so we turn off typechecking for the
# implementation below.
# Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
if self.drop_last:
return len(self.sampler) // self.batch_size # type: ignore[arg-type]
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type]
collate_fn
collate_fn is the function that batches data: it merges several samples into one batch. A single sample may contain several different "fields"; a multimodal sample, for instance, might contain text, an image, and a label, and we want the same field from every sample packed together for convenient processing, e.g. all the text in one list and all the images in another. collate_fn takes one batch's worth of samples as input and returns the batch with matching fields packed together.
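A quick illustration with the default collate function (assuming a recent PyTorch where default_collate is exported from torch.utils.data; the (image, label) samples are made up):

import torch
from torch.utils.data import default_collate

# Three samples, each an (image, label) pair.
batch = [(torch.randn(3, 4, 4), 0),
         (torch.randn(3, 4, 4), 1),
         (torch.randn(3, 4, 4), 2)]

images, labels = default_collate(batch)
print(images.shape)  # torch.Size([3, 3, 4, 4]) -- images stacked along dim 0
print(labels)        # tensor([0, 1, 2])        -- ints packed into one tensor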
The default implementation relies on a dictionary that maps types to handler functions; when the data contains a type registered in this dictionary, the corresponding handler is used to process that type (field):
default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {torch.Tensor: collate_tensor_fn}
with contextlib.suppress(ImportError):
import numpy as np
# For both ndarray and memmap (subclass of ndarray)
default_collate_fn_map[np.ndarray] = collate_numpy_array_fn
# See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html
# Skip string scalars
default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn
default_collate_fn_map[float] = collate_float_fn
default_collate_fn_map[int] = collate_int_fn
default_collate_fn_map[str] = collate_str_fn
default_collate_fn_map[bytes] = collate_str_fn
# The default collate_fn ends up calling this function.
def collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
elem = batch[0]
    # Use the first element of the batch to determine the type, which decides how the whole batch is handled.
elem_type = type(elem)
    # If the element type is registered in collate_fn_map, dispatch to the registered handler.
if collate_fn_map is not None:
if elem_type in collate_fn_map:
return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
for collate_type in collate_fn_map:
if isinstance(elem, collate_type):
return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)
if isinstance(elem, collections.abc.Mapping):
try:
return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
except TypeError:
# The mapping type may not support `__init__(iterable)`.
return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
    # For sequence types, zip first groups the same field across all samples.
elif isinstance(elem, collections.abc.Sequence):
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError('each element in list of batch should be of equal size')
        # Group the same field from every sample together.
transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
        # Then recursively collate each per-field list.
if isinstance(elem, tuple):
return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] # Backwards compatibility.
else:
try:
return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
except TypeError:
# The sequence type may not support `__init__(iterable)` (e.g., `range`).
return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
raise TypeError(default_collate_err_msg_format.format(elem_type))
# If the batch elements are tensors, create a new dimension 0 and stack the whole batch along it.
def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
elem = batch[0]
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum(x.numel() for x in batch)
storage = elem._typed_storage()._new_shared(numel, device=elem.device)
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
return torch.stack(batch, 0, out=out)
# If the batch elements are numpy arrays, convert them to tensors first and hand them back to collate for further processing.
def collate_numpy_array_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
elem = batch[0]
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
# If the batch elements are numpy scalars, convert the whole batch to a tensor directly.
def collate_numpy_scalar_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
return torch.as_tensor(batch)
# If the batch elements are Python floats, convert the whole batch to a float64 tensor.
def collate_float_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
return torch.tensor(batch, dtype=torch.float64)
# If the batch elements are Python ints, convert the whole batch to a tensor.
def collate_int_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
return torch.tensor(batch)
# If the batch elements are strings or bytes, return the batch unchanged (usually a tuple of strings).
def collate_str_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
return batch
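Putting the dispatch rules together, the default collate function recurses through nested containers. A small sketch with dict-shaped samples (the field names are made up):

import torch
from torch.utils.data import default_collate

batch = [{"img": torch.zeros(2, 2), "label": 0, "name": "a"},
         {"img": torch.ones(2, 2),  "label": 1, "name": "b"}]

out = default_collate(batch)
print(out["img"].shape)  # torch.Size([2, 2, 2]) -- tensors stacked by collate_tensor_fn
print(out["label"])      # tensor([0, 1])        -- ints handled by collate_int_fn
print(out["name"])       # ['a', 'b']            -- strings returned as-is by collate_str_fn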
DataLoaderIter
DataLoaderIter is DataLoader's internal iterator and handles the actual data loading. DataLoader's __iter__ method simply returns such an iterator object, which ties the Dataset, Sampler, BatchSampler, and collate_fn described above into a single pipeline. Depending on num_workers, the iterator is either a single-process or a multi-process variant.
To create the iterator, the DataLoader passes itself in as the constructor argument:
def _get_iterator(self) -> '_BaseDataLoaderIter':
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIter(self)
Its initialization is implemented in the base class _BaseDataLoaderIter and mainly copies the DataLoader's member variables into the iterator itself. During iteration, __iter__ returns the iterator itself, and __next__ calls _next_data to obtain the data; _next_data is not implemented in the base class and must be provided by subclasses:
class _BaseDataLoaderIter:
def __init__(self, loader: DataLoader) -> None:
self._dataset = loader.dataset
self._shared_seed = None
self._pg = None
if isinstance(self._dataset, IterDataPipe):
if dist.is_available() and dist.is_initialized():
self._pg = dist.new_group(backend="gloo")
self._shared_seed = _share_dist_seed(loader.generator, self._pg)
shared_rng = torch.Generator()
shared_rng.manual_seed(self._shared_seed)
self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
self._dataset_kind = loader._dataset_kind
self._IterableDataset_len_called = loader._IterableDataset_len_called
self._auto_collation = loader._auto_collation
self._drop_last = loader.drop_last
self._index_sampler = loader._index_sampler
self._num_workers = loader.num_workers
ws, rank = _get_distributed_settings()
self._world_size = ws
self._rank = rank
# for other backends, pin_memory_device need to set. if not set
# default behaviour is CUDA device. if pin_memory_device is selected
# and pin_memory is not set, the default behaviour false.
if (len(loader.pin_memory_device) == 0):
self._pin_memory = loader.pin_memory and torch.cuda.is_available()
self._pin_memory_device = None
else:
if not loader.pin_memory:
warn_msg = ("pin memory device is set and pin_memory flag is not used then device pinned memory won't be used"
"please set pin_memory to true, if you need to use the device pin memory")
warnings.warn(warn_msg)
self._pin_memory = loader.pin_memory
self._pin_memory_device = loader.pin_memory_device
self._timeout = loader.timeout
self._collate_fn = loader.collate_fn
self._sampler_iter = iter(self._index_sampler)
self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
self._persistent_workers = loader.persistent_workers
self._num_yielded = 0
self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"
def __iter__(self) -> '_BaseDataLoaderIter':
return self
def _reset(self, loader, first_iter=False):
self._sampler_iter = iter(self._index_sampler)
self._num_yielded = 0
self._IterableDataset_len_called = loader._IterableDataset_len_called
if isinstance(self._dataset, IterDataPipe):
self._shared_seed = _share_dist_seed(loader.generator, self._pg)
shared_rng = torch.Generator()
shared_rng.manual_seed(self._shared_seed)
self._dataset = torch.utils.data.graph_settings.apply_random_seed(self._dataset, shared_rng)
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
def _next_data(self):
raise NotImplementedError
def __next__(self) -> Any:
with torch.autograd.profiler.record_function(self._profile_name):
if self._sampler_iter is None:
# TODO(https://github.com/pytorch/pytorch/issues/76750)
self._reset() # type: ignore[call-arg]
data = self._next_data()
self._num_yielded += 1
if self._dataset_kind == _DatasetKind.Iterable and \
self._IterableDataset_len_called is not None and \
self._num_yielded > self._IterableDataset_len_called:
warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
"samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
self._num_yielded)
if self._num_workers > 0:
warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
"IterableDataset replica at each worker. Please see "
"https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
warnings.warn(warn_msg)
return data
def __len__(self) -> int:
return len(self._index_sampler)
def __getstate__(self):
# TODO: add limited pickling support for sharing an iterator
# across multiple threads for HOGWILD.
# Probably the best way to do this is by moving the sample pushing
# to a separate thread and then just sharing the data queue
# but signalling the end is tricky without a non-blocking API
raise NotImplementedError("{} cannot be pickled", self.__class__.__name__)
The subclasses implement _next_data with the help of a Fetcher class:
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super().__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
# Adds forward compatibilities so classic DataLoader can work with DataPipes:
# Taking care of distributed sharding
if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
# For BC, use default SHARDING_PRIORITIES
torch.utils.data.graph_settings.apply_sharding(self._dataset, self._world_size, self._rank)
        # The fetcher object performs the actual data retrieval.
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def _next_data(self):
        # First call _next_index(), which asks the (batch) sampler for the next batch of indices.
index = self._next_index() # may raise StopIteration
        # `data` is already a fully collated batch at this point.
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
return data
    # self._sampler_iter is iter(self._index_sampler)
    # _index_sampler is usually the BatchSampler
def _next_index(self):
return next(self._sampler_iter)
# Returns a different Fetcher depending on the dataset kind; usually a map-style dataset.
class _DatasetKind:
Map = 0
Iterable = 1
@staticmethod
def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
if kind == _DatasetKind.Map:
return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
else:
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
# The base class mainly handles initialization.
class _BaseDatasetFetcher:
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
self.dataset = dataset
self.auto_collation = auto_collation
self.collate_fn = collate_fn
self.drop_last = drop_last
def fetch(self, possibly_batched_index):
raise NotImplementedError()
# This class actually fetches the data.
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
        # auto_collation is True when a batch_sampler is in use.
if self.auto_collation:
            # If the dataset defines __getitems__ (note the plural, not __getitem__), use it; most datasets do not implement it.
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
data = self.dataset.__getitems__(possibly_batched_index)
else:
                # Otherwise fetch the samples one by one using the batch of indices.
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
        # Hand the fetched batch to collate_fn for processing.
return self.collate_fn(data)
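To make the fetch step concrete, here is a hand-driven sketch of _MapDatasetFetcher. It lives in the private torch.utils.data._utils package, so this is purely illustrative and may break across versions:

import torch
from torch.utils.data import TensorDataset, default_collate
from torch.utils.data._utils.fetch import _MapDatasetFetcher

dataset = TensorDataset(torch.arange(10).float())
fetcher = _MapDatasetFetcher(dataset, auto_collation=True,
                             collate_fn=default_collate, drop_last=False)

# Pretend the BatchSampler produced these indices for one mini-batch.
batch = fetcher.fetch([2, 5, 7])
print(batch)  # [tensor([2., 5., 7.])]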
Flow Summary
Overall, the flow is as follows:
- The Dataset we created is passed into the DataLoader class, whose initializer builds the corresponding sampler, batch_sampler, collate_fn, and so on;
- Using for data in dataloader to fetch data first triggers the DataLoader's __iter__ method, which returns the corresponding DataLoaderIter object;
- Each loop step then triggers the iterator's __next__ method, which calls the Fetcher object's fetch method;
- fetch first obtains indices from the sampler, retrieves the corresponding data, and packs it with collate_fn before returning. A manual walk-through of this pipeline is sketched below.
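Stringing the public pieces together by hand gives a simplified sketch of what one pass over the DataLoader does (workers, pin_memory, and IterableDataset handling are ignored):

import torch
from torch.utils.data import (TensorDataset, RandomSampler,
                              BatchSampler, default_collate)

dataset = TensorDataset(torch.arange(1, 30))
sampler = RandomSampler(dataset)                        # shuffled indices
batch_sampler = BatchSampler(sampler, batch_size=8, drop_last=False)

for indices in batch_sampler:                           # one list of <=8 indices per step
    samples = [dataset[i] for i in indices]             # what _MapDatasetFetcher does
    batch = default_collate(samples)                    # pack matching fields together
    print(batch[0])                                     # a tensor of up to 8 elements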
In addition, actual training usually runs for multiple epochs. Each epoch obtains a fresh iterator object through the DataLoader's __iter__, and the seed used by its RandomSampler is regenerated, so the data order at the start of each epoch usually differs from that of the previous epoch.
import torch
from torch.utils.data import Dataset, DataLoader

class myDataset(Dataset):
    def __init__(self) -> None:
        self.data = torch.arange(start=1, end=30)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

dataloader = DataLoader(myDataset(), shuffle=True, batch_size=8)
for epoch in range(3):
    print("--------")
    for data in dataloader:
        print(data)
# Note that each epoch prints the data in a different order.
Output:
--------
tensor([20, 2, 7, 10, 23, 3, 28, 15])
tensor([17, 18, 5, 14, 13, 9, 27, 29])
tensor([16, 11, 6, 1, 12, 24, 4, 19])
tensor([21, 8, 22, 25, 26])
--------
tensor([ 5, 23, 9, 2, 15, 4, 18, 26])
tensor([13, 16, 22, 11, 28, 19, 12, 3])
tensor([17, 10, 27, 20, 24, 25, 1, 7])
tensor([14, 29, 21, 6, 8])
--------
tensor([15, 19, 23, 18, 1, 17, 25, 22])
tensor([ 6, 5, 10, 24, 7, 13, 20, 14])
tensor([ 2, 21, 9, 28, 26, 3, 12, 11])
tensor([29, 27, 8, 4, 16])