Pytorch数据处理工具
概述
- torch.utils.data共有如下四个类。
- Dataset:是一个抽象类,其他的数据集需要继承这个类,并且覆盖其中的两个方法(
_getitem_
,_len_
)。 - DataLoader:定义一个新的迭代器,实现批量(batch)读取,打乱数据(shuffle)并且提供并行加速等功能。
- random_split:把数据集随机拆分为给定长度的非重叠的新数据集。
- sampler:多种采样函数
- Torchvision是PyTorch的一个视觉处理工具包,包括四个子类,如下
- dataset:提供常用的数据集的加载,设计上都是继承自torch.utils.data.Dataset,主要包括MMIST、ImageNet、COCO等
- models:提供深度学习中的各种经典的网络结构以及训练好的模型,包括AlexNet、VGG、ResNet、Inception等
- transforms:常用的数据预处理操作,主要包括对Tensor及PIL Image对象的操作
- utils:包含两个函数,一个是make_grid,它能将多张图片拼接在一个网格中;另一个是save_img,它能将Tensor保存成图片
utils.data简介
utils.data 包括Dataset和DataLoader。torch.utils.data.Dataset为抽象类。自定义的数据集需要继承这个类,并且实现两个函数(_getitem_
,_len_
)前者通过索引获取数据和标签,后者提供数据的大小。_getitem_
一次只能获取一个数据,所以需要通过torch.utils.daata.DataLoader来定义一个新的迭代器,实现批量(batch读取)。
import torch
from torch.utils import data
import numpy as np
# 基于基类Dataset,自定义一个数据集及对应的标签
class TestDataset(data.Dataset): # 继承Dataset
def __init__(self):
self.Data = np.asarray([[1, 2], [3, 4], [5, 6]])
self.Label = np.asarray([0, 1, 0])
def __getitem__(self, item):
txt = torch.from_numpy(self.Data[item])
label = torch.tensor(self.Label[item])
return txt,label
def __len__(self):
return len(self.Data)
# 获取数据集
test = TestDataset()
print(test[2]) # 相当于调用了__getitem__(2)
print(test.__len__())
以上tuple返回,每次只返回一个样本,实际上,Dataset只负责数据的抽取,调用一次__getitem__只返回一个样本。如果希望批处理,还需呀同时进行shuffle和并行加速等操作,可选择DataLoader。格式如下:
data.DataLoader(
dataset=test,
batch_size=2,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
)
参数说明
- dataset:加载的数据集
- batch_size:批大小
- shuffle:是否将数据打乱
- sampler:样本抽样
- num_workers:使用多进程加载的进程数,0代表不使用多进程
- pin_memory:是否将数据保存在pin memory区,pin memory转化到GPU会更快
- drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃。
torchvision简介
torchvision有四个功能模块:model、datasets、transforms和utils。下面主要介绍使用 datasets的ImageFolder处理的自定义数据集,以及transforms对源数据进行预处理、增强等。
transforms
transforms提供了对PIL Images对象和Tensor对象的常用操作
- 对PIL Image的常见操作
- Scale/Resize:调整尺寸,长款比保持不变
- CenterCrop、ReandomCrop、RandomSizedCrop:裁剪图片,CenterCrop、ReandomCrop在crop时候是固定size,RandomSizedCrop则是random size的crop
RandomSizedCrop变换已弃用,RandomResizedCrop代替
- pad:填充
- ToTensor:把一个取值范围为[0,255]的PIL.Image转那位为Tensor。形状为(H,W,C)的Numpy.ndarray转化成形状为[H,W,C],取值范围为[0,1,0]的torch.FloatTensor
- RandomHorizontalFlip:图像随机水平翻转,翻转概率为0.5
- RandomVerticalFlip:图像随机垂直翻转
- ColorJitter:修改亮度、对比度和饱和度
- 对Tensor的常见操作
- Normalize:标准化,即,减均值,除以标准差
- ToPILImage:将Tensor转化为PILImage
如果要对数据集尽心多个操作,可以通过Compose将这些操作像管道一样拼接起来,类似于nn.Sequential如下:
transforms.Compose([
transforms.CenterCrop(10),
transforms.RandomCrop(20, padding=0),
transforms.ToTensor()
])
ImageFolder
当文件依据标签处于不停的文件下,如
——data
|——data1
|——001.jpg
|——002.jpg
|——003.jpg
|——data2
|——001.jpg
|——006.jpg
我们可以利用torchvision.datasets.ImageFolder来直接构造除dataset,代码如下
loader = torchvision.datasets.ImageFolder(path)
loader = data.DataLoader(dataset)
ImageFolder会自动将目录中的文件夹名自动转化成序列,当DataLoader载入时,标签自动就是整数序列了。
方法一
my_train = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
train_data = datasets.ImageFolder('./data/train_neg', my_train)
train_loader = data.DataLoader(train_data, batch_size=6, shuffle=False)
for i_batch, img in enumerate(train_loader):
方法二
my_train = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
train_data = datasets.ImageFolder('./data/train_neg', my_train)
train_loader = data.DataLoader(
dataset=train_data,
shuffle=False,
batch_size=6
)
for i_batch, img in enumerate(train_loader):