目录
Pytorch的数据集
Pytorch深度学习库以一种可读性强、模块化程度高的方式来构建深度学习网络。在构建深度学习网络时,数据的加载和预处理是一项重要而繁琐的工作。如果在构建网络中, 我们需要为加载样本数据、样本数据预处理编写大量的处理代码,会导致代码变得混乱、网络构建过程不清晰,最终难以维护。
基于以上考虑,Pytorch将数据集和数据集的加载定义为两个单独对象,使数据集代码和模型训练代码相分离,以获得更好的可读性和模块化。
Pytorch提供了两个DataSet和DataLoader两个类。
DataSet
DataSet是数据集对象类, Pytorch提供了大量的默认数据集, 包括Fashion-MINST、CIFAR-10、CIFAR-100、CelebA等数据集。如果用户想要加载自定义的数据只需要继承DataSet类。
Pytorch支持两种类型的DataSet:
- Map类型DataSet
- Iterable类型DataSet
Map类型DataSet
Map类型DataSet实现__getitem__()
和 __len__()
,表示从索引/键到数据样本的映射。数据集在使用 访问时,可以通过索引直接获取相关样本数据。例如,dataset[idx]表示使用
idx
从磁盘上的文件夹中读取第i个图像及其相应的标签。
Iterable类型DataSet
IterableDataset
实现了__iter__()函数
,可对数据样本进行迭代访问。这种类型的数据集特别适用于随机读取代价高昂以及批量大小取决于获取的数据等场景。