概述
Pytorch对于数据集的操作使用 DataSet 和 DataLoader 。
- DataSet是一个抽象类,不能被实例化,只能由其他类继承,构造自己的类。
- DataLoader需要知道每一个数据的索引,以及数据的长度。是可以实例化的。
一般来说,我们在使用Pytorch来加载数据集时,需要通过继承 DataSet 类,并实现其中的两个抽象方法,一个是支持下标索引操作的__getitem__
,一个是支持获取数据长度的__len__
。
在加载数据时,一般使用Mini-Batch,原因有以下两个:
(1)通过使用batch,可以在梯度下降中更新参数时通过随机梯度下降的方法,这可以帮助我们跨越部分鞍点。(鞍点会导致梯度为0)
(2)通过将数据划分为batch,可以充分利用计算机并行计算的能力,加快计算的速度。
对于Mini-Batch,有这么几个概念:Epoch, Batch-Size, Iterations
(1) Epoch:所有训练示例经过一次向前传播和一个反向传播。
(2)Batch-Size:一次向前向和反向传播中训练示例的数量。
(3)Iteration:(等于Epoch / Iteration)。一个批次批输入到模型中,称为一个Iteration
代码示例
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class DiabetesDataset(Dataset): #创建自己的类,继承抽象类DataSet
def __init__(self): #初始化方法,构造函数
pass
def __getitem__(self, index): #支持索引的下标操作
pass
def __len__(self): #获取数据的长度
pass
dataset = DiabetesDataset() #实例化对象
train_loader = DataLoader(dataset=dataset, #加载数据集
batch_size=32, # 设置batch-size的大小
shuffle=True, # 设置是否打乱
num_workers=2) # 读的时候并行的进程
注意:数据集不太大时,直接读进去,加载到数据集中是可以的。但是如果数据集很大,例如图像这种可以达到G级别的,需要在初始化函数中做其他的一些处理,让其每次只读取到其中的一部分,避免其全部加载到内存当中。
Pytorch的官方数据集
Pytorch官网中datasets里的数据集,都提供了DataSet的两个抽象方法,因此可以用直接用DataLoader加载。