Pytorch 数据加载与数据预处理

本文介绍了PyTorch中数据加载的基本方法,包括torchvision.datasets自带的数据集和自定义数据集的加载方式。同时详细解释了如何利用ImageFolder类加载分类图像数据,以及如何创建自定义数据集类。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

数据加载分为加载torchvision.datasets中的数据集以及加载自己使用的数据集两种情况。

torchvision.datasets中的数据集

torchvision.datasets中自带MNIST,Imagenet-12,CIFAR等数据集,所有的数据集都是torch.utils.data.Dataset的子类,都包含 _ _ len _ (获取数据集长度)和 _ getItem _ _ (获取数据集中每一项)两个子方法。
这里写图片描述
Dataset源码如上,可以看到其中包含了两个没有实现的子方法,之后所有的Dataet类都继承该类,并根据数据情况定制这两个子方法的具体实现。

因此当我们需要加载自己的数据集的时候也可以借鉴这种方法,只需要继承torch.utils.data.Dataset类并重写 init ,len,以及getitem这三个方法即可。这样组着的类可以直接作为参数传入到torch.util.data.DataLoader中去。
以CIFAR10为例 源码

class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version. E.g, transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
加载自己的数据集

对于torchvision.datasets中有两个不同的类,分别为DatasetFolder和ImageFolder,ImageFolder是继承自DatasetFolder。
下面我们通过源码来看一看folder文件中DatasetFolder和ImageFolder分别做了些什么

import torch.utils.data as data
from PIL import Image
import os
import os.path


def has_file_allowed_extension(filename, extensions):  //检查输入是否是规定的扩展名
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in extensions)


def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] //获取root目录下所有的文件夹名称

    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))} //生成类别名称与类别id的对应Dictionary
    return classes, class_to_idx


def make_dataset(dir, class_to_idx, extensions):
    images = []
    dir = os.path.expanduser(dir)// 将~和~user转化为用户目录,对参数中出现~进行处理
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
   
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值