本文记录博主学习心得,如理解有误,请及时指正交流哦。THANKS!
DataLoader是PyTorch中的一种数据类型。
PyTorch中数据读取的一个重要接口是torch.utils.data.DataLoader,该接口定义在dataloader.py脚本中,只要是用PyTorch来训练模型基本都会用到该接口,该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor。(方便产生一个可迭代对象(iterator),每次输出指定batch_size大小的Tensor)
dataloader.py脚本的的github地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py
代码一般是这么写的:
# 定义学习集 DataLoader
train_data = torch.utils.data.DataLoader(各种设置...)
# 将数据喂入神经网络进行训练
for i, (input, target) in enumerate(train_data):
循环代码行..