上一篇博客简单介绍了MNIST数据集经典数据集MNIST(手写数字数据集)详解与可视化-优快云博客,对于大多数模型来说基本上都能在MNIST数据集上跑出比较好的结果,我刚开始接触MNIST的时候直接使用全连接层的网络结果也能得到90%以上的准确率。
今天再来介绍一下相比于MNIST较为复杂的CIFAR-10数据集,后续准备写模型训练的文章也基本上用CIFAR-10作为数据集。
数据集介绍
深度学习和机器学习领域中,可能最常见的挑战之一就是图像分类。对于这一任务,我们需要大量具有标签的数据集,来训练和验证我们的模型。在众多的公开可用的数据集中,有一个十分重要且广泛使用的是CIFAR-10。
CIFAR-10是由加拿大高级研究所(Canadian Institute for Advanced Research,简称CIFAR)发布的一个重要图像分类数据集。该数据集包含10个不同的类别,分别是:飞机,汽车,鸟类,猫,鹿,狗,青蛙,马,船,和卡车。这些类别均匀分布在数据集中,每个类别都有6000张图像。 所有的图像均为32x32像素的彩色图像,总共的图像数量为60000张。其中,50000张被作为训练数据,剩下的10000张则作为测试数据。
对于机器学习实践者来说,CIFAR-10是一个难度适中的数据集,其既不会使模型过于简单,也不会使训练过程变得过于复杂。 而对于研究者来说,CIFAR-10提供了一个公平的竞技场,用于比较各种图像分类算法。众多研究中,用到了CIFAR-10作为他们算法性能的评价标准。包括但不限于神经网络、卷积神经网络、K近邻算法、支持向量机和决策树等。
关于数据集的官方网页介绍https://www.cs.toronto.edu/~kriz/cifar.html
值得注意的是,这里每一个类别的图片一定是相互独立的,即每一张图片有且只对应一个标签,有小伙伴可能注意到类别中存在“automoble”和“truck”,在CIFAR10官方的说明网页里面也说明了这两个类别是互相独立的。原文“The classes are completely mutually exclusive. There is no overlap between automobiles and trucks. "Automobile" includes sedans, SUVs, things of that sort. "Truck" includes only big trucks. Neither includes pickup trucks.”
数据集下载
与MNIST相同,CIFAR同样提供了方便地代码下载接口,运行前需要导入的库:
# 数据转换方式
from torchvision import transforms
# CIFAR10下载接口
from torchvision.datasets import CIFAR10
# 可视化
import matplotlib.pyplot as plt
数据集下载
# 定义数据转换方式
my_trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 数据集下载
train_dataset = CIFAR10('D:/deep_learning/12_16/data/', train=True, transform=my_trans, download=True)
test_dataset = CIFAR10('D:/deep_learning/12_16/data/', train=False, transform=my_trans, download=True)
具体的参数介绍可以参考上一篇博客https://blog.youkuaiyun.com/weixin_57506268/article/details/135055111这里就不过多赘述
打印train_data的数据:
Number of datapoints: 50000
Root location: D:/deep_learning/12_16/data/
Split: Train
StandardTransform
Transform: Compose(
ToTensor()
Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
)
含有50,000张图片数据,符合官网描述格式
标签对应
从对数据集的描述来看我们知道,每张图片的标签应该是一个对应类别的字符串,那我们下载得到的数据是不是直接储存字符串呢?我们定义变量赋值到标签打印出来看看
# 取第一张图片数据
img,label = train_dataset[0]
print(label)
# 结果输出为 6
所以可见并不是直接储存字符串标签,而是换成了对应的序号,所以我们需要事先定义一个列表(或者是字典),列表的下标对应下载数据的标签,列表对应的值为数字标签对应的字符串类别
class_str = "airplane|automobile|bird|cat|deer|dog|frog|horse|ship|truck"
classes = class_str.split("|")
获得列表classes:['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
数据可视化
显示单张图片
# 取第五张图片数据
img,label = train_dataset[4]
plt.imshow(img.permute(1,2,0))
plt.title(classes[label])
plt.show()
由于原图像是的像素,所以不是特别清晰,但从标签和轮廓可以大概看出,图片属于汽车的类别
还需注意的是,不同于MNIST数据集的单通道灰度图像,CIFAR10数据集是三通道的彩色图片,在数据转换时经过ToTensor()变换的每张图片储存格式转换为[C,W,H],也就是将通道数放在像素宽与高的前面,这种数据格式是不能直接通过plt.imshow()显示的,需要将图片的格式转换回[W,H,C]的格式,把通道数放在第三维度,这就是在代码中需要进行img.permute(1,2,0)变换的原因 ,将原数据的第一维度变换到第三维度上
显示多张图片
# 显示前6张图片
for i in range(6):
img,label = train_dataset[i]
plt.subplot(2,3,i+1)
plt.imshow(img.permute(1,2,0))
plt.title(classes[label])
plt.show()
补充
可视化运行的时候会有以下提示:
表示将[0,1]范围的数据映射到[0,255]范围进行显示
欢迎大家讨论交流~