torchvision 库是服务于pytorch深度学习框架的,用来生成图片,视频数据集,和一些流行的模型类和预训练模型.
torchvision.datasets
所有数据集都是 torch.utils.data.dataset 的子类,也就是说,它们都实现了 __getitem__ 和 __len__ 方法。因此,它们都可以传递给 torch.utils.data.dataloader,后者可以使用 torch.multiprocessing workers 并行加载多个样本。例如:
imagenet_data = torchvision.datasets.ImageFolder('path/to/imagenet_root/')
data_loader = torch.utils.data.DataLoader(imagenet_data,
batch_size=4,
shuffle=True,
num_workers=args.nThreads)
可获得的数据集如下:
目录
所有数据集都有几乎相似的API,都有两个共同的参数:transform 和 target_transform 分别对 input 和 target 进行转换 。
MNIST
CLASS torchvision.datasets.MNIST
(root, train=True, transform=None, target_transform=None, download=False)
0-9手写数字 数据集。
Parameters: |
|
---|
Fashion-MNIST
CLASS torchvision.datasets.
FashionMNIST
(root, train=True, transform=None, target_transform=None, download=False)
10类衣服标签的数据集。
每个 training 和 test 示例的标签如下:
Label | Description |
---|---|
0 | T-shirt/top |
1 | Trouser |
2 | Pullover |
3 | Dress |
4 | Coat |
5 | Sandal |
6 | Shirt |
7 | Sneaker |
8 | Bag |
9 | Ankle boot |
KMNIST
CLASS torchvision.datasets.
KMNIST
(root, train=True, transform=None, target_transform=None, download=False)
手写日语片假名 数据集。
EMNIST
CLASS torchvision.datasets.
EMNIST
(root, split, **kwargs)
MNIST数据库来自更大的数据集,称为NIST特殊数据库19,其包含数字,大写和小写手写字母。 完整NIST数据集的变体,称为扩展MNIST(EMNIST),它遵循用于创建MNIST数据集的相同转换范例。
Parameters: |
|
---|
FakeData
CLASS torchvision.datasets.
FakeData
(size=1000, image_size=(3, 224, 224), num_classes=10, transform=None, target_transform=None, random_offset=0)
假数据集,返回随机生成的图像并将其作为PIL图像返回。
Parameters: |
---|
COCO
需要安装Coco API
COCO数据集的使用:https://www.cnblogs.com/q735613050/p/8969452.html
Captions
CLASS torchvision.datasets.
CocoCaptions
(root, annFile, transform=None, target_transform=None)
MS Coco Captions 数据集。
Parameters: |
---|
例子:
import torchvision.datasets as dset
import torchvision.transforms as transforms
cap = dset.CocoCaptions(root = 'dir where images are',
annFile = 'json annotation file',
transform=transforms.ToTensor())
print('Number of samples: ', len(cap))
img, target = cap[3] # load 4th sample
print("Image Size: ", img.size())
print(target)
# output:
Number of samples: 82783
Image Size: (3L, 427L, 640L)
[u'A plane emitting smoke stream flying over a mountain.',
u'A plane darts across a bright blue sky behind a mountain covered in snow',
u'A plane leaves a contrail above the snowy mountain top.',
u'A mountain that has a plane flying overheard in the distance.',
u'A mountain view with a plume of smoke in the background']
__getitem__
(index)
Parameters: | index (int) – Index |
---|---|
Returns: | Tuple (image, target) target 是 image 的标题列表。 |
Return type: | tuple |
Detection
CLASS torchvision.datasets.
CocoDetection
(root, annFile, transform=None, target_transform=None)
MS Coco Detaction 数据集。
__getitem__
(index)
Parameters: | index (int) – Index |
---|---|
Returns: | Tuple (image, target). target是coco.loadAnns返回的对象。 |
Return type: | tuple |
LSUN
CLASS torchvision.datasets.
LSUN
(root, classes='train', transform=None, target_transform=None)
Parameters: |
---|

ImageFolder
CLASS torchvision.datasets.
ImageFolder
(root, transform=None, target_transform=None, loader=<function default_loader>)
通用数据加载器,其中图像以这种方式排列:
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Parameters: |
---|
DatasetFolder
CLASS torchvision.datasets.
DatasetFolder
(root, loader, extensions, transform=None, target_transform=None)
通用数据加载器,其中样本以这种方式排列:
root/calss_x/xxx.ext
root/calss_x/xxy.ext
root/calss_x/xxz.ext
root/calss_y/123.ext
root/calss_y/nsdf3.ext
root/calss_y/asd32_.ext
Parameters: |
---|
Imagenet-12
这应该只使用 ImageFolder 数据集实现。ImageNet大规模视觉识别挑战(ILSVRC)数据集有1000个类别和120万个图像。 图像不需要在任何数据库中进行预处理或打包,但需要将验证图像移动到适当的子文件夹中。
CIFAR
CLASS torchvision.datasets.
CIFAR10
(root, train=True, transform=None, target_transform=None, download=False)
CIFAR10 数据集由10个类中的60000个32x32彩色图像组成,每个类有6000个图像。 有50000个训练图像和10000个测试图像。数据集分为五个训练 batch 和一个测试 batch ,每个 batch 有10000个图像。 测试 batch 包含来自每个类别的1000个随机选择的图像。训练 batch 以随机顺序包含剩余图像,但是一些训练 batch 可能包含来自一个类别的更多图像而不是另一个类别。 training batch包含来自每个 class 的5000个图像。
Parameters: |
---|
CLASS torchvision.datasets.
CIFAR100
(root, train=True, transform=None, target_transform=None, download=False)
CIFAR100 数据集与CIFAR-10类似,不同之处在于它有100个类,每个类包含600个图像。 每个类有500个训练图像和100个测试图像。 CIFAR-100中的100个类被分为20个超类。 每个图像都带有一个“精细”标签(它所属的类)和一个“粗”标签(它所属的超类)。
STL10
CLASS torchvision.datasets.
STL10
(root, split='train', transform=None, target_transform=None, download=False)
STL-10 数据集是用于开发无监督特征学习,深度学习,自学习学习算法的图像识别数据集。它的灵感来自CIFAR-10数据集,但有一些修改。特别地,每个类具有比CIFAR-10更少的标记训练示例,但是提供了非常大的一组未标记示例以在监督训练之前学习图像模型。 主要的挑战是利用未标记的数据(来自与标记数据相似但不同的分布)来构建有用的先验数据。 期望该数据集的更高分辨率(96x96)将使其成为开发更具可扩展性的无监督学习方法的具有挑战性的基准。
Parameters: |
---|
SVHN
CLASS torchvision.datasets.
SVHN
(root, split='train', transform=None, target_transform=None, download=False)
SVHN数据集(the Street View House Numbers (SVHN) 街景号码数据集)注意:SVHN数据集将标签10分配给数字0。但是,在此数据集中,我们将标签0分配给数字0以与PyTorch损失函数兼容,这些函数期望类标签在[0,C-1]范围内。
Parameters: |
---|
PhotoTour
CLASS torchvision.datasets.
PhotoTour
(root, name, train=True, transform=None, download=False)
数据集由1024 x 1024位图(.bmp)图像组成,每个图像包含16 x 16阵列的图像块。每个 patch 采样为64 x 64灰度,具有规范的比例和方向。关联的元数据文件 info.txt 包含匹配信息。 info.txt 的每一行对应一个单独的 patch, patch 在每个位图图像中从左到右,从上到下排序。 info.txt每行的第一个数字是从中采样该 patch 的3D点ID - 具有相同3D点ID的 patch 从相同的3D点投射到不同的图像中。 info.txt中的第二个数字对应于采样 patch 的图像,目前尚未使用。
__getitem__
(index)
Parameters: | index (int) – Index |
---|---|
Returns: | (data1, data2, matches) |
Return type: | tuple |
SBU
CLASS torchvision.datasets.
SBU
(root, transform=None, target_transform=None, download=True)
Im2Text:使用100万张标题照片描述图像。
Flickr
CLASS torchvision.datasets.
Flickr8k
(root, ann_file, transform=None, target_transform=None)
Parameters: |
---|
__getitem__
(index)
Parameters: | index (int) – Index |
---|---|
Returns: | Tuple (image, target). target is a list of captions(字幕) for the image. |
Return type: | tuple |
CLASS torchvision.datasets.
Flickr30k
(root, ann_file, transform=None, target_transform=None)
VOC
CLASS torchvision.datasets.
VOCSegmentation
(root, year='2012', image_set='train', download=False, transform=None, target_transform=None)
Parameters: |
---|
__getitem__
(index)
Parameters: | index (int) – Index |
---|---|
Returns: | (image, target) 其中 target 是 image segmentation(分割). |
Return type: |
CLASS torchvision.datasets.
VOCDetection
(root, year='2012', image_set='train', download=False, transform=None, target_transform=None)
__getitem__
(index)
Parameters: | index (int) – Index |
---|---|
Returns: | (image, target) 其中 target is a dictionary of the XML tree(是XML树的字典). |
Return type: | tuple |
Cityscapes
需要下载 cityscape。
CLASS torchvision.datasets.
Cityscapes
(root, split='train', mode='fine', target_type='instance', transform=None, target_transform=None)
Parameters: |
|
---|
例子
获取语义分割目标
dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
target_type='semantic')
img, smnt = dataset[0]
获得多个目标
dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
target_type=['instance', 'color', 'polygon'])
img, (inst, col, poly) = dataset[0]
在“coarse”集上验证
dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
target_type='semantic')
img, smnt = dataset[0]
__getitem__
(index)
Parameters: | index (int) – Index |
---|---|
Returns: | (image, target) 如果target_type是具有多个项目的列表,target是所有目标类型的元组。否则,如果target_type =“polygon”,则target是json对象,否则是图像分割。 |
Return type: | tuple |