利用Tensorflow训练CIFAR-10数据集的训练部分源码详细分析
实验环境
ubuntu16.04
python2.7
tensorflow1.4
代码
代码一共有三个文件,分别是cifar10_input.py, cifar10.py, cifar_train.py。首先是cifar10_input.py,代码如下:
# coding:utf-8
# 绝对引入
from __future__ import absolute_import
# 导入精确除法
from __future__ import division
# 导入print函数,print要使用括号
from __future__ import print_function
import os
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
IMAGE_SIZE = 24
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
# 总类别数
NUM_CLASSES = 10
# 定义read.cifar10函数,将数据弄成图像形式
def read_cifar10(filename_queue):
# 定义类
class CIFAR10Record(object):
pass
result = CIFAR10Record()
# 标签字节数
label_bytes = 1
# 图片的高
result.height = 32
# 图片的宽
result.width = 32
# 图片的深度
result.depth = 3
# 一张图片的字节数
image_bytes = result.height * result.width * result.depth
# 标签加图片的字节数
record_bytes = label_bytes + image_bytes
# 定义每次读多少字节
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
# 读取文件队列名中的数据
result.key, value = reader.read(filename_queue)
# tf.decode_raw函数是将原来编码为字符串类型的变量重新变回来
record_bytes = tf.decode_raw(value, tf.uint8)
# tf.strided_slice(input_, begin, end)提取张量的一部分
result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]),
[result.depth, result.width, result.height])
# transpose函数是将矩阵进行转置操做,[0, 1, 2]中0表示高(深度),1表示行,2表示列。
# [1, 2, 0]表示将原来的行作为高,列作为行,高作为列,对矩阵进行转置。
# [1, 2, 0]刚好将矩阵变为了 行×列×深度 的矩阵。
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
return result
# 定义数据增强函数
def distorted_inputs(data_dir, batch_size):
# 读取数据
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1,6)]
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
# 生成文件队列名,每一个文件就是一个data_batch
filename_queue = tf.train.string_input_producer(filenames)
# 调用cifar10_read函数,将数据变为图像形式
read_input = read_cifar10(filename_queue)
# 转换数据格式,从而得到训练使用的原始数据
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
height = IMAGE_SIZE
width = IMAGE_SIZE
# 随机裁剪图片,从32×32变为24×24
distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
# 随机翻转图片。50%的概率翻转
distorted_image = tf.image.random_flip_left_right(distorted_image)
# 随机改变亮度,在(-63,63)区间内调整
distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
# 随机调整对比度,在(0.2,1.8)区间内选择对比因子
distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)
# 返回具有零均值和单位范数的标准化图片,将像素值的大小限制到一个范围内,加速训练
float_image = tf.image.per_image_standardization(distorted_image)
# 设置输入图像的格式
float_image.set_shape([height, width, 3])
read_input.label.set_shape([1])
# 队列中最少图像个数
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN*min_fraction_of_examples_in_queue)
print('Filling queue with %d CIFAR images befor starting to train.'
'This will take a a few minutes' % min_queue_examples)
# 返回一个batch的图像和对应的标签,shuffle:打乱
return _generate_image_and_label_batch(float_image, read_input.label,
min_queue_examples, batch_size, shuffle=True)
# 定义构建图像和标签文件名队列的函数
def _generate_image_and_label_batch(image, label, min_queue_examples, batch_size, shuffle):
# num_tfreads的值
num_preprocess_tfreads = 16
if shuffle:
# tf.train.shuffle_batch函数中,设置num_threads的值大于1,则使用多个线程在tensor_list中读取文件。
# capacity:队列中最大的元素数,一定要比min_after_dequeue的值大,决定了可以进行预处理操做元素的最大值。
# min_after_dequeue:当一次出列操做完成后,取出的队列中元素的最小数量,用来定义混合级别。
images, label_batch = tf.train.shuffle_batch([image, label], batch_size=batch_size,
num_threads=num_preprocess_tfreads,
capacity=min_queue_examples + 3 * batch_size,
min_after_dequeue=min_queue_examples)
else:
images, label_batch = tf.train.batch([image, label], batch_size=batch_size,
num_threads=num_preprocess_tfreads,
capacity=min_queue_examples + 3 * batch_size)
# 在tensorboard中可视化训练图像
tf.summary.image('images', images)
# 返回图像和标签, tf.reshape返回的是一个张量
return images, tf.reshape(label_batch, [batch_size])
# 创建CIFAR测试用的输入函数?
def inputs(eval_data, data_dir, batch_size):
# eval_data是测试数据?
if not eval_data:
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1