利用Tensorflow训练CIFAR-10数据集的训练部分源码详细分析

本文详细分析了利用Tensorflow训练CIFAR-10数据集的训练部分源码,涵盖了cifar10_input.py、cifar10.py和cifar_train.py三个关键文件的内容。实验结果显示,训练过程中屏幕输出了时间、步数、损失值和训练速率等信息,并通过Tensorboard进行了可视化,展示了参数值、训练图像、直方图和结构图等。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

利用Tensorflow训练CIFAR-10数据集的训练部分源码详细分析

实验环境

ubuntu16.04
python2.7
tensorflow1.4

代码

代码一共有三个文件,分别是cifar10_input.py, cifar10.py, cifar_train.py。首先是cifar10_input.py,代码如下:

# coding:utf-8
# 绝对引入
from __future__ import absolute_import
# 导入精确除法
from __future__ import division
# 导入print函数,print要使用括号
from __future__ import print_function

import os

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

IMAGE_SIZE = 24
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
# 总类别数
NUM_CLASSES = 10

# 定义read.cifar10函数,将数据弄成图像形式
def read_cifar10(filename_queue):
    # 定义类
    class CIFAR10Record(object):
        pass
    result = CIFAR10Record()

    # 标签字节数
    label_bytes = 1
    # 图片的高
    result.height = 32
    # 图片的宽
    result.width = 32
    # 图片的深度
    result.depth = 3
    # 一张图片的字节数
    image_bytes = result.height * result.width * result.depth
    # 标签加图片的字节数
    record_bytes = label_bytes + image_bytes
    # 定义每次读多少字节
    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    # 读取文件队列名中的数据
    result.key, value = reader.read(filename_queue)

    # tf.decode_raw函数是将原来编码为字符串类型的变量重新变回来
    record_bytes = tf.decode_raw(value, tf.uint8)
    # tf.strided_slice(input_, begin, end)提取张量的一部分
    result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
    depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]),
                             [result.depth, result.width, result.height])
    # transpose函数是将矩阵进行转置操做,[0, 1, 2]中0表示高(深度),1表示行,2表示列。
    # [1, 2, 0]表示将原来的行作为高,列作为行,高作为列,对矩阵进行转置。
    # [1, 2, 0]刚好将矩阵变为了  行×列×深度  的矩阵。
    result.uint8image = tf.transpose(depth_major, [1, 2, 0])

    return result


# 定义数据增强函数
def distorted_inputs(data_dir, batch_size):
    # 读取数据
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1,6)]
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    # 生成文件队列名,每一个文件就是一个data_batch
    filename_queue = tf.train.string_input_producer(filenames)
    # 调用cifar10_read函数,将数据变为图像形式
    read_input = read_cifar10(filename_queue)
    # 转换数据格式,从而得到训练使用的原始数据
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)

    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # 随机裁剪图片,从32×32变为24×24
    distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
    # 随机翻转图片。50%的概率翻转
    distorted_image = tf.image.random_flip_left_right(distorted_image)
    # 随机改变亮度,在(-63,63)区间内调整
    distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
    # 随机调整对比度,在(0.2,1.8)区间内选择对比因子
    distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)

    # 返回具有零均值和单位范数的标准化图片,将像素值的大小限制到一个范围内,加速训练
    float_image = tf.image.per_image_standardization(distorted_image)

    # 设置输入图像的格式
    float_image.set_shape([height, width, 3])
    read_input.label.set_shape([1])

    # 队列中最少图像个数
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN*min_fraction_of_examples_in_queue)
    print('Filling queue with %d CIFAR images befor starting to train.'
          'This will take a a few minutes' % min_queue_examples)

    # 返回一个batch的图像和对应的标签,shuffle:打乱
    return _generate_image_and_label_batch(float_image, read_input.label,
                                           min_queue_examples, batch_size, shuffle=True)


# 定义构建图像和标签文件名队列的函数
def _generate_image_and_label_batch(image, label, min_queue_examples, batch_size, shuffle):
    # num_tfreads的值
    num_preprocess_tfreads = 16
    if shuffle:
        # tf.train.shuffle_batch函数中,设置num_threads的值大于1,则使用多个线程在tensor_list中读取文件。
        # capacity:队列中最大的元素数,一定要比min_after_dequeue的值大,决定了可以进行预处理操做元素的最大值。
        # min_after_dequeue:当一次出列操做完成后,取出的队列中元素的最小数量,用来定义混合级别。
        images, label_batch = tf.train.shuffle_batch([image, label], batch_size=batch_size,
                                                     num_threads=num_preprocess_tfreads,
                                                     capacity=min_queue_examples + 3 * batch_size,
                                                     min_after_dequeue=min_queue_examples)
    else:
        images, label_batch = tf.train.batch([image, label], batch_size=batch_size,
                                             num_threads=num_preprocess_tfreads,
                                             capacity=min_queue_examples + 3 * batch_size)
    # 在tensorboard中可视化训练图像
    tf.summary.image('images', images)
    # 返回图像和标签, tf.reshape返回的是一个张量
    return images, tf.reshape(label_batch, [batch_size])


# 创建CIFAR测试用的输入函数?
def inputs(eval_data, data_dir, batch_size):
    # eval_data是测试数据?
    if not eval_data:
        filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值