关于mnist,可参考「学习笔记」torchvision.datasets.MNIST 参数解读/中文使用手册
在学习torch.utils.data.DataLoader的时候偶然发现这个Loader可传参数还蛮多,在PyTorch中文文档中未能搜索到这个Loader,故网上收集、翻译在此,仅做笔记之用。
英文手册如下:
PS:找资料的过程中找到了一篇完成好的博文:https://blog.youkuaiyun.com/rogerfang/article/details/82291464,摘录部分,将笔者学习笔记一并整理如下,感谢原作者辛勤付出。
init(构造函数)中的几个重要的属性:
1、dataset:(数据类型 dataset)
顾名思义,PyTorch中的数据集类型。
2、batch_size:(数据类型 int)
每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一个一个进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。
3、shuffle:(数据类型 bool)
洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。
4、collate_fn:(数据类型 callable)
将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。(不太明白作用是什么,就暂时默认False)
笔者注:用样本列表合并一个mini-bacth,通常在map型数据加载中使用
5、batch_sampler:(数据类型 Sampler)
批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。
6、sampler:(数据类型 Sampler)
采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。
7、num_workers:(数据类型 Int)
工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。
注:可以理解为子进程数,官方文档中用的单词是subprocesses
8、pin_memory:(数据类型 bool)
默认为False。在数据返回前,是否将数据复制到CUDA内存中。
注:如果数据元素是自定义类型,或者collate返回自定义类型的批处理
9、drop_last:(数据类型 bool)
丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。
如果数据集大小不能被 batch_size 整除,则设置为True可删除最后一个未完成的batch。如果为False,并且数据集的大小不能被 batch_size 整除,则最后一个batch将更小。(默认值:False)
10、timeout:(数据类型 numeric)
超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。
11、worker_init_fn(数据类型 callable,没见过的类型)
子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据。