transforms_tutorial 主要介绍了PyTorch中的transforms模块,它可以帮助我们在数据预处理时进行多种变换操作。
首先,它介绍了基本的图像变换,如裁剪、旋转、翻转等。然后,它介绍了随机变换,如随机裁剪、扭曲、亮度、对比度等。这些随机变换使得我们可以扩展数据集,从而提高模型的泛化能力。
接着,教程介绍了如何将transforms与PyTorch的数据加载器Dataset和DataLoader结合起来使用。这使得我们可以在加载数据时对它们进行变换。
最后介绍了如何自定义transforms。我们可以编写自己的函数来实现特定的变换,从而更好地满足我们的需求。
转换
数据并不总是以最终处理好的形式出现,以便于进行机器学习算法的训练。我们使用 transform
来对数据进行一些操作,使其适合于训练。
所有的 TorchVision 数据集都有两个参数—— transform
用于修改特征, target_transform
用于修改标签,这两个参数可以接受包含转换逻辑的可调用函数。torchvision.transforms 模块提供了多个常用的转换。
FashionMNIST 中的特征是 PIL Image 格式,标签是整数。对于训练,我们需要将特征转换为归一化的张量,将标签转换为独热编码张量。为了进行这些转换,我们使用 ToTensor
和 Lambda
。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
# 加载FashionMNIST数据集,对特征进行 ToTensor() 转换,对标签进行 Lambda() 转换
ds = datasets.FashionMNIST(
root="~/data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
ToTensor()
ToTensor() 方法将PIL图像或NumPy ndarray
转换为 FloatTensor
,并将图像的像素强度值缩放到 [0., 1.] 范围内。
Lambda Transforms
Lambda transforms应用任何用户定义的lambda函数。这里,我们定义一个函数将整数转换为one-hot编码张量。它首先创建一个大小为10(数据集中的标签数)的零张量,并调用 scatter_ 将值 value=1
分配给标签y
指定的索引。
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))