TensorFlow实战笔记之(5):卷积神经网络 实现CIFAR-10数据集分类

本文介绍了使用TensorFlow实现CIFAR-10数据集分类的卷积神经网络(CNN)方法,详细讲解了CIFAR-10数据集的读取和数据增强,探讨了卷积神经网络的设计,包括结构、BN层的原理和使用,并展示了训练过程及结果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

一、CIFAR-10数据集

1.简介

CIFAR-10是一个经典的数据集,包含60000张RGB 32x32像素的图像,其中训练集50000张,测试集10000张。CIFAR-10即标注为10类,每一类图片6000张,如下图。

                                        

数据集下载地址:https://www.cs.toronto.edu/~kriz/cifar.html

下载下来的数据集中包含下面的一些文件:

                           

这里共有六个主要文件,其中五个训练数据文件,文件名为:data_batch_1.bin,…, data_batch_5.bin,一个测试数据文件,名为test_batch.bin。每个文件都是用cPickle生成的python“pickled”对象。下面是python3例程,可以打开此类文件并返回字典:

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

以这种方式加载的每个批处理文件都包含一个字典,包含以下元素:

  • data -- 10000x3072 的uint8s格式numpy数组。数组的每一行存储一个32x32的彩色图像,按顺序包含红色、绿色和蓝色三个通道的值,因此每行的长度为32x32x3=3072。图像按行进行存储,如数组的前32个值是图像第一行的红色通道值。
  • labels -- 取值为0-9的包含10000个数字的list。索引i处的数字表示数组data中第i个图像的标签。

数据集中还有另一个名为batches.meta.txt的文件,它也包含一个python字典对象: 

  • label_names -- 一个10元素的list,为上述labels的具体名称。如, label_names[0] == "airplane", label_names[1] == "automobile" 等。

2.读取数据集

为了读取CIFAR-10中的图像数据,你可以像上面的例程那样自己写一个读取数据的程序,比如这样

另一种方法是使用tfds.load,即

dataset = tfds.load(name='cifar10', split=split)

具体见https://www.tensorflow.org/datasets/api_docs/python/tfds/load

当然,最傻瓜的方法还是直接使用官方提供的代码,当然官方代码里面也是用tfds.load实现的。

代码下载地址:https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10,包含以下内容:

                               

读取数据集的代码在cifar10_input.py文件中,这个文件里的代码会不定时的更新,所以如果直接使用以前别人博客或书里的调用例程可能会出现一些错误,目前cifar10_input.py文件里主要有以下内容:

_get_images_labels():获取图像以及标签数据。

DataPreprocessor():数据预处理类。

distorted_inputs():构建训练数据并进行预处理。

inputs():构建测试数据并进行预处理(也可以用在训练集上)。

调用方法:

from cifar10 import cifar10_input

# 训练集
images_train, labels_train = cifar10_input.distorted_inputs(batch_size=batch_size)

# 测试集
images_test, labels_test = cifar10_input.inputs(eval_data=True, batch_size=batch_size)

需要注意的是,这里对训练集数据进行了数据增强(Data Augmentation),具体可查看cifar10_input里的DataPreprocessor类,其中的数据增强操作包括随机剪切一块24×24大小的图片(tf.random_crop),随机的水平翻转(tf.image.random_flip_left_right),设置随机的亮度和对比度(tf.image.random_brightness、tf.image.random_contrast),以及对数据进行标准化(tf.image.per_image_standardization,减去均值并除以像素的方差,使模型对图像的动态范围变化不敏感)。如下图:

                            &n

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值