pytorch入门(六)Pytorch数据处理工具

本文介绍了PyTorch中的数据处理工具torch.utils.data,包括Dataset、DataLoader的使用,以及如何自定义数据集。DataLoader用于批量读取和数据增强。另外,详细讲解了Torchvision库,特别是ImageFolder和transforms模块,用于图像数据的加载和预处理,如尺寸调整、随机裁剪、颜色 jitter 和标准化等。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

概述

  • torch.utils.data共有如下四个类。
  1. Dataset:是一个抽象类,其他的数据集需要继承这个类,并且覆盖其中的两个方法(_getitem_,_len_)。
  2. DataLoader:定义一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)并且提供并行加速等功能。
  3. random_split:把数据集随机拆分为给定长度的非重叠的新数据集。
  4. sampler:多种采样函数
  • Torchvision是PyTorch的一个视觉处理工具包,包括四个子类,如下
  1. dataset:提供常用的数据集的加载,设计上都是继承自torch.utils.data.Dataset,主要包括MMIST、ImageNet、COCO等
  2. models:提供深度学习中的各种经典的网络结构以及训练好的模型,包括AlexNet、VGG、ResNet、Inception等
  3. transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作
  4. utils:包含两个函数,一个是make_grid,它能将多张图片拼接在一个网格中;另一个是save_img,它能将Tensor保存成图片

utils.data简介

utils.data 包括Dataset和DataLoader。torch.utils.data.Dataset为抽象类。自定义的数据集需要继承这个类,并且实现两个函数(_getitem_,_len_)前者通过索引获取数据和标签,后者提供数据的大小。_getitem_一次只能获取一个数据,所以需要通过torch.utils.daata.DataLoader来定义一个新的迭代器,实现批量(batch读取)。

import torch
from torch.utils import data
import numpy as  np

# 基于基类Dataset,自定义一个数据集及对应的标签
class TestDataset(data.Dataset): # 继承Dataset
    def __init__(self):
        self.Data = np.asarray([[1, 2], [3, 4], [5, 6]])
        self.Label = np.asarray([0, 1, 0])

    def __getitem__(self, item):
        txt = torch.from_numpy(self.Data[item])
        label = torch.tensor(self.Label[item])
        return txt,label

    def __len__(self):
        return len(self.Data)

# 获取数据集
test = TestDataset()
print(test[2])   # 相当于调用了__getitem__(2)
print(test.__len__())

以上tuple返回,每次只返回一个样本,实际上,Dataset只负责数据的抽取,调用一次__getitem__只返回一个样本。如果希望批处理,还需呀同时进行shuffle和并行加速等操作,可选择DataLoader。格式如下:

data.DataLoader(
    dataset=test,
    batch_size=2,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
)

参数说明

  • dataset:加载的数据集
  • batch_size:批大小
  • shuffle:是否将数据打乱
  • sampler:样本抽样
  • num_workers:使用多进程加载的进程数,0代表不使用多进程
  • pin_memory:是否将数据保存在pin memory区,pin memory转化到GPU会更快
  • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃。

torchvision简介

torchvision有四个功能模块:model、datasets、transforms和utils。下面主要介绍使用 datasets的ImageFolder处理的自定义数据集,以及transforms对源数据进行预处理、增强等。

transforms

transforms提供了对PIL Images对象和Tensor对象的常用操作

  • 对PIL Image的常见操作
    • Scale/Resize:调整尺寸,长款比保持不变
    • CenterCrop、ReandomCrop、RandomSizedCrop:裁剪图片,CenterCrop、ReandomCrop在crop时候是固定size,RandomSizedCrop则是random size的cropRandomSizedCrop变换已弃用,RandomResizedCrop代替
    • pad:填充
    • ToTensor:把一个取值范围为[0,255]的PIL.Image转那位为Tensor。形状为(H,W,C)的Numpy.ndarray转化成形状为[H,W,C],取值范围为[0,1,0]的torch.FloatTensor
    • RandomHorizontalFlip:图像随机水平翻转,翻转概率为0.5
    • RandomVerticalFlip:图像随机垂直翻转
    • ColorJitter:修改亮度、对比度和饱和度
  • 对Tensor的常见操作
    • Normalize:标准化,即,减均值,除以标准差
    • ToPILImage:将Tensor转化为PILImage

如果要对数据集尽心多个操作,可以通过Compose将这些操作像管道一样拼接起来,类似于nn.Sequential如下:

transforms.Compose([
    transforms.CenterCrop(10),
    transforms.RandomCrop(20, padding=0),
    transforms.ToTensor()
])

ImageFolder

当文件依据标签处于不停的文件下,如
——data
    |——data1
        |——001.jpg
        |——002.jpg
        |——003.jpg
    |——data2
        |——001.jpg
        |——006.jpg
我们可以利用torchvision.datasets.ImageFolder来直接构造除dataset,代码如下

loader = torchvision.datasets.ImageFolder(path)
loader = data.DataLoader(dataset)

ImageFolder会自动将目录中的文件夹名自动转化成序列,当DataLoader载入时,标签自动就是整数序列了。
方法一

my_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])
train_data = datasets.ImageFolder('./data/train_neg', my_train)
train_loader = data.DataLoader(train_data, batch_size=6, shuffle=False)
for i_batch, img in enumerate(train_loader):

方法二

my_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
])

train_data = datasets.ImageFolder('./data/train_neg', my_train)
train_loader = data.DataLoader(
    dataset=train_data,
    shuffle=False,
    batch_size=6
)
for i_batch, img in enumerate(train_loader):
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

啊~小 l i

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值