第二章 数据处理篇:transforms

本文详细介绍了PyTorch中的transforms模块,用于图像数据预处理,包括ToTensor()、Normalize()、Resize()、RandomCrop()、RandomHorizontalFlip等方法,以及如何组合使用这些方法进行图像变换,以适应模型训练需求。

教程参考:
https://pytorch.org/tutorials/
https://github.com/TingsongYu/PyTorch_Tutorial
https://github.com/yunjey/pytorch-tutorial
详细的transform的使用样例可以参考:ILLUSTRATION OF TRANSFORMS


为什么要使用transforms

你得到的原始数据,可能并不是你期望的用于模型训练的数据的形式,比如数据中图像的大小不同、数据的格式不对。这时就需要你对数据进行统一的处理,torchvision.transforms就提供了一些帮助我们进行数据处理的简易手段。

在pytorch官方教程最开始,给了这样一个示例。
示例中使用自带的datasets:FashionMNIST,为了便于训练,对于原始数据和label分别使用了transform的方法。
对于数据本身,使用的方法是 ToTensor(),
对于标签,使用的方法是one-hot。
在后面的部分我们会详细介绍一下不同的transform方法。

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

torchvision.transforms中提供了多个方法,并且这些方法可以使用Compose进行连接,并按顺序执行。其中的大部分transforms方法都可以接受PIL图像和tensor图像作为输入,当然也有一部分在输入上有限制。

transforms方法举例

我们使用opencv读入一张cifar10中的图片作为例子,并将其通道从BGR转为RGB通道。使用opencv读入的图片,为numpy.ndarray格式。下图是我们的例子,一个类别为airplane的图像。
在这里插入图片描述

ToTensor()

ToTensor()方法可以把一个PIL图像或者numpy.ndarray数据转成FloatTensor的形式,并且将图像规范化到0和1之间。
更细致地来说,它会把一共PIL图像,或者范围在[0,255]的大小为(HxWxC)的numpy.ndarray转成一个大小为(CxHxW)的范围在[0.0,1.0]的floattensor。ndarray数据的dtype必须是np.uint8。
在这里插入图片描述

使用ToTensor()方法对我们的img进行处理,可以看到它原本为uint8的ndarray,变成了float32的tensor,它的形状从(32, 32, 3)转为(3, 32, 32),并且它的像素值的大小从51 到 255被转变为0.2到1.0。

我们也可以将图像读取为PIL Image的形式,并使用同样的方法处理。得到的结果是完全相同的。
在这里插入图片描述

Normalize()

Normalize()方法可以把一个tensor数据进行归一化/标准化处理。在使用时,需要你提供数据的均值和方差,Normalize()会对输入数据的每一个通道进行归一化处理。使用的方法是:
o u t p u t [ c h a n n e l ] = i n p u t [ c h a n n e l ] − m e a n [ c h a n n e l ] s t d [ c h a n n e l ] output[channel] = \frac{input[channel] - mean[channel]}{std[channel]} output[channel]=

### 处理《鱼书》第三章 MNIST 源码 404 错误 当遇到MNIST数据集下载失败或返回404错误时,可以采取多种方法来解决问题。一种常见的原因是原始下载链接失效或是网络连接不稳定造成的。 对于《深度学习入门:基于Python的理论与实现》,即通常所说的《鱼书》,如果按照书中指示尝试获取MNIST数据集却遭遇了404错误,则建议采用如下替代方案: #### 替代下载方式一:更改默认的数据源URL 可以通过修改代码中的资源列表,指向其他可靠的镜像站点以解决官方网址不可达的问题。例如,在某些版本的`chainer.datasets.get_mnist()`函数内部定义了一个名为`resources`的变量用于指定文件位置及其对应的MD5校验值[^1]。因此,可以在本地环境中调整这些设置,确保能够顺利取得所需资料。 ```python import os from urllib.request import urlretrieve import gzip, shutil def download_and_extract(url, filename): if not os.path.exists(filename[:-3]): print(f'Downloading {filename}...') urlretrieve(url, filename) with gzip.open(filename, 'rb') as f_in: with open(filename[:-3], 'wb') as f_out: shutil.copyfileobj(f_in, f_out) os.remove(filename) # 删除压缩包 if __name__ == "__main__": resources = [ ("https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz", "train-images.idx3-ubyte"), ("https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz", "train-labels.idx1-ubyte"), ("https://ososci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz", "t10k-images.idx3-ubyte"), ("https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz", "t10k-labels.idx1-ubyte") ] for resource_url, file_name in resources: download_and_extract(resource_url, file_name + ".gz") ``` 此脚本会自动检测是否存在未解压的目标文件;如果没有找到,则先下载相应的`.gz`格式存档再进行解压缩处理,并最终移除临时产生的压缩文档副本。 #### 替代下载方式二:利用第三方库辅助加载 除了直接操作HTTP请求外,还可以借助于专门设计用来简化机器学习流程管理工作的工具——比如`tensorflow_datasets`或者`torchvision`等框架自带的功能模块来完成相同任务。这类高级接口不仅提供了更便捷的方法去访问流行公开数据集合,而且往往内置有缓存机制以及多线程加速特性,有助于提高整体效率并减少潜在的风险因素。 例如使用TensorFlow Datasets: ```python import tensorflow_datasets as tfds # 加载MNIST训练集和测试集 dataset_train = tfds.load('mnist', split='train') dataset_test = tfds.load('mnist', split='test') for example in dataset_train.take(1): # 取第一个样本查看结构 image, label = example["image"], example["label"] print(image.shape, label.numpy()) ``` 又或者是通过PyTorch Vision快速部署: ```python from torchvision import datasets, transforms transform=transforms.Compose([ transforms.ToTensor(), ]) train_dataset = datasets.MNIST('./data/', train=True, download=True, transform=transform) test_dataset = datasets.MNIST('./data/', train=False, transform=transform) print(len(train_dataset), len(test_dataset)) ``` 这两种途径都能有效规避因原生API局限所带来的困扰,同时也为后续开发工作奠定了坚实的基础。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值