一、CIFAR-10数据集
1.简介
CIFAR-10是一个经典的数据集,包含60000张RGB 32x32像素的图像,其中训练集50000张,测试集10000张。CIFAR-10即标注为10类,每一类图片6000张,如下图。
数据集下载地址:https://www.cs.toronto.edu/~kriz/cifar.html
下载下来的数据集中包含下面的一些文件:
这里共有六个主要文件,其中五个训练数据文件,文件名为:data_batch_1.bin,…, data_batch_5.bin,一个测试数据文件,名为test_batch.bin。每个文件都是用cPickle生成的python“pickled”对象。下面是python3例程,可以打开此类文件并返回字典:
def unpickle(file):
import pickle
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
以这种方式加载的每个批处理文件都包含一个字典,包含以下元素:
- data -- 10000x3072 的uint8s格式numpy数组。数组的每一行存储一个32x32的彩色图像,按顺序包含红色、绿色和蓝色三个通道的值,因此每行的长度为32x32x3=3072。图像按行进行存储,如数组的前32个值是图像第一行的红色通道值。
- labels -- 取值为0-9的包含10000个数字的list。索引i处的数字表示数组data中第i个图像的标签。
数据集中还有另一个名为batches.meta.txt的文件,它也包含一个python字典对象:
- label_names -- 一个10元素的list,为上述labels的具体名称。如, label_names[0] == "airplane", label_names[1] == "automobile" 等。
2.读取数据集
为了读取CIFAR-10中的图像数据,你可以像上面的例程那样自己写一个读取数据的程序,比如这样。
另一种方法是使用tfds.load,即
dataset = tfds.load(name='cifar10', split=split)
具体见https://www.tensorflow.org/datasets/api_docs/python/tfds/load。
当然,最傻瓜的方法还是直接使用官方提供的代码,当然官方代码里面也是用tfds.load实现的。
代码下载地址:https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10,包含以下内容:
读取数据集的代码在cifar10_input.py文件中,这个文件里的代码会不定时的更新,所以如果直接使用以前别人博客或书里的调用例程可能会出现一些错误,目前cifar10_input.py文件里主要有以下内容:
_get_images_labels():获取图像以及标签数据。
DataPreprocessor():数据预处理类。
distorted_inputs():构建训练数据并进行预处理。
inputs():构建测试数据并进行预处理(也可以用在训练集上)。
调用方法:
from cifar10 import cifar10_input
# 训练集
images_train, labels_train = cifar10_input.distorted_inputs(batch_size=batch_size)
# 测试集
images_test, labels_test = cifar10_input.inputs(eval_data=True, batch_size=batch_size)
需要注意的是,这里对训练集数据进行了数据增强(Data Augmentation),具体可查看cifar10_input里的DataPreprocessor类,其中的数据增强操作包括随机剪切一块24×24大小的图片(tf.random_crop),随机的水平翻转(tf.image.random_flip_left_right),设置随机的亮度和对比度(tf.image.random_brightness、tf.image.random_contrast),以及对数据进行标准化(tf.image.per_image_standardization,减去均值并除以像素的方差,使模型对图像的动态范围变化不敏感)。如下图:
&n