1、下载CIFAR10(训练集、测试集)
import torchvision
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, download=True)
1、查看测试集中第一个数据的信息:
print(test_set[0])
结果为:
前半部分是图片信息,后半部分的数字 “3” 是类别。
2、查看测试集中所有类别:
print(test_set.classes)
结果为:
3、另一种查看图片和标签的方式:
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
结果为:
展示一下图片:
img.show()
2、将Dataset 与 transforms进行联动
1、首先需要将数据转换为tensor的格式:
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), ])
完整代码如下:
import torchvision
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transform, download=True)
2、将数据集用tensorboard显示:
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transform, download=True)
writer = SummaryWriter('P8')
for i in range(10):
img, target = test_set[i]
writer.add_image('test_set', img, i)
writer.close()
结果如下: