torchvision
是 PyTorch 的一个官方库,主要用于处理计算机视觉任务。提供了许多常用的数据集、模型架构、图像转换等功能,使得计算机视觉任务的开发变得更加高效和便捷。以下是对 torchvision
主要功能的详细介绍:
1. 数据集(Datasets)
torchvision
提供了许多常用的计算机视觉数据集,如 CIFAR-10、MNIST、ImageNet 等。这些数据集可以直接通过 torchvision.datasets
模块加载。
示例:加载 CIFAR-10 数据集
from torchvision import datasets
from torch.utils.data import DataLoader
# 加载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True)
# 使用 DataLoader 加载数据
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
2. 图像转换(Transforms)
torchvision.transforms
模块提供了许多常用的图像转换操作,如裁剪、缩放、旋转、翻转等。