1. torchvision.datasets模块提供很多内置数据集,方便下载和调用
比如
import torchvision
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
imagenet_data1 = torchvision.datasets.CIFAR10('path/to/images_root/')
imagenet_data2 = torchvision.datasets.FashionMNIST('path/to/images_root/')
imagenet_data3 = torchvision.datasets.Kitti('./', download=True)
imagenet_data4 = torchvision.datasets.Sintel('./', download=True)
包含了有很多数据集,包括分类,检测,分割,光流等数据集,具体可以查看网页:
[DATASETS]https://pytorch.org/vision/0.13/datasets.html
2. torchvision.datasets是torch.utils.data.Dataset的一个子类
因此包含__getitem__ 和 __len__等方法,也可以传递给 torch.utils.data.DataLoader 方便多线程采样数据。
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)