pycharm加载本地数据集_PyTorch加载自己的数据集

本文详细介绍了如何在PyTorch中加载本地数据集,特别是图片和JSON标注文件的CIFAR10数据集。通过分析torchvision.datasets.CIFAR10的结构,自定义Datasets类,并利用DataLoader进行多线程批量加载,以提高数据读取效率和训练速度。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、前言

在深度学习中,需要加载数据对神经网络进行训练,现有的主流数据集及常用的经典数据集例如COCO,MINIST,CIFAR等,在许多开源的项目中例如MMCV,torchvision中都有对应的加载,对于自己的数据集而言,应该如何加载自定义数据集呢。

torchvision.datasets - PyTorch master documentation​pytorch.org

本文就图片数据集的加载进行分析,数据集为转化为图片和json标注文件的CIFAR10数据集,数据集的文件格式如下所示torchvision.datasets - PyTorch master documentation本文就图片数据集的加载进行分析,数据集为转化为图片和json标注文件的CIFAR10数据集,数据集的文件格式如下所示

b97198636ecff1d08df80c22c25c78ee.png

a5bb0148614794bea2e89ebfb2e9dbff.png

将CIFAR10数据集转换成图片文件和json文件的标注参照这篇文章:

HUST小菜鸡:CIFAR10数据集转换成图片及标注文件​zhuanlan.zhihu.com

二、直接读取

#读取文件位置
def get_path('path-str'):
    ...
    return file_path

#读取图片
def loader_img(file_path):
    #根据图片的位置读取图片并返回读取的图片和标签
    #对图片进行处理
    ...

    return imgs_list, label_list

#获取batchsize大小的数据
def get_train_data(imgs_list,label_list,batchsize):
    ...
    return img[1],img[2],...,img[batchsize]

以上是常规的思路,在原理上来说是可行的,但是如果batchsize很大,那么用这种方式去读取数据集会带来如下弊端:

  • 将所有的图像数据直接加载到numpy数据中会占用大量的内存
  • 由于需要对数据进行导入,每次训练的时候在数据读取阶段会占用大量的时间
  • 只使用了单线程去读取,读取效率比较低下
  • 拓展性很差,只能对数据进行一些单一的预处理

PyTorch中有工具函数torch.utils.Data.DataLoader,通过这个函数我们在准备加载数据集使用mini-batch的时候可以使用多线程并行处理,这样可以加快我们准备数据集的速度。Datasets就是构建这个工具函数的实例参数之一。这样我们就可以批量加载数据或者并行加载数据

三、class Datasets

3eb5c6599cbb5a38d4784b42db09ce0b.png
torch.utils.data - PyTorch master documentation​pytorch.org
class Dataset(object):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值