中文文档:https://ptorch.com/docs/8/torch-torchvision
torchvision包括了目前流行的数据集,模型结构和常用的图片转换工具,是PyTorch中专门用来处理图像的库。这个包中有四个大类:
- torchvision.datasets
- torchvision.models
- torchvision.transforms
- torchvision.utils
所有数据集都是torch.utils.data.Dataset的子类, 即它们具有getitem和len实现方法。它们可以传递给torch.utils.data.DataLoader,工作人员可以使用torch.multiprocessing并行加载多个样本的数据。例如:
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
所有的数据集都有几乎相似的API。他们都有两个共同的参数:transform 和 target_transform 分别转换输入和目标。
1 torchvision.datasets
torchvision.datasets 是用来进行数据加载的,这个包中提前处理好了很多图片数据集。
- MNIST
- COCO(用于图像标注和目标检测)(Captioning and Detection)
- LSUN Classification
- ImageFolder
- Imagenet-12
- CIFAR10 and CIFAR100
- STL10
- SVHN
- PhotoTour
1.1 MNIST
下载MNIST数据集,代码如下:
import torchvision
# 下载训练集
train_data = torchvision.datasets.MNIST(
root='path',
train=True,
transform=None,
target_transform=True,
download=False,
)
# 下载测试集
test_data = torchvision.datasets.MNIST(
root='path',
train=False,
transform=None,
target_transform=True,
download=False,
)
参数说明:
- root:数据集,存在于根目录processed/training.pt 和 processed/test.pt中。
- train:True = 训练集,False = 测试集
- download:如果为true,请从Internet下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。
transform:
接收PIL映像并返回转换版本的函数/变换。例如:transform.RandomCrop
target_transform:
一个接收目标并转换它的函数/变换。