Pytorch加载数据集(DataSet和DataLoader)

概述

Pytorch对于数据集的操作使用 DataSet 和 DataLoader 。
- DataSet是一个抽象类,不能被实例化,只能由其他类继承,构造自己的类。
- DataLoader需要知道每一个数据的索引,以及数据的长度。是可以实例化的。
一般来说,我们在使用Pytorch来加载数据集时,需要通过继承 DataSet 类,并实现其中的两个抽象方法,一个是支持下标索引操作的__getitem__,一个是支持获取数据长度的__len__


在加载数据时,一般使用Mini-Batch,原因有以下两个:
(1)通过使用batch,可以在梯度下降中更新参数时通过随机梯度下降的方法,这可以帮助我们跨越部分鞍点。(鞍点会导致梯度为0)
(2)通过将数据划分为batch,可以充分利用计算机并行计算的能力,加快计算的速度。

对于Mini-Batch,有这么几个概念:Epoch, Batch-Size, Iterations
(1) Epoch:所有训练示例经过一次向前传播和一个反向传播。

(2)Batch-Size:一次向前向和反向传播中训练示例的数量。

(3)Iteration:(等于Epoch / Iteration)。一个批次批输入到模型中,称为一个Iteration


代码示例

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class DiabetesDataset(Dataset): #创建自己的类,继承抽象类DataSet
    def __init__(self):	#初始化方法,构造函数
        pass
    def __getitem__(self, index):	#支持索引的下标操作
        pass
    def __len__(self):	#获取数据的长度
        pass
        
    dataset = DiabetesDataset()	#实例化对象
    train_loader = DataLoader(dataset=dataset, #加载数据集
                              batch_size=32, # 设置batch-size的大小
                              shuffle=True, # 设置是否打乱
                              num_workers=2) # 读的时候并行的进程

注意:数据集不太大时,直接读进去,加载到数据集中是可以的。但是如果数据集很大,例如图像这种可以达到G级别的,需要在初始化函数中做其他的一些处理,让其每次只读取到其中的一部分,避免其全部加载到内存当中。


Pytorch的官方数据集


Pytorch官网中datasets里的数据集,都提供了DataSet的两个抽象方法,因此可以用直接用DataLoader加载。

参考资料

[1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=8

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

半岛铁子_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值