TensorFlow CIFAR-10 模型源码
model
# -*- coding: utf-8 -*-
# cifar模型:图像分类
#
# Author: Igor
import gzip
import os
import re
import sys
import tarfile
import urllib.request
import tensorflow as tf
from TensorFlow.cifar import cifar10_input
FLAGS = tf.app.flags.FLAGS

# Basic model parameters, exposed as command-line flags.
tf.app.flags.DEFINE_integer(
    'batch_size',
    128,
    "number of images to process in a batch",
)
tf.app.flags.DEFINE_string(
    'data_dir',
    'data/',
    "Path to the CIFAR-10 data directory",
)
# Global constants describing the CIFAR-10 data set, re-exported from
# cifar10_input so the rest of this module has a single place to read them.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASS = cifar10_input.NUM_CLASSES
# The training epoch size must come from the *training* constant. Taking it
# from the eval constant would give every per-epoch computation (presumably
# the learning-rate decay schedule) the wrong epoch length.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = (
    cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
)
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
# Global constants for the training process.
MOVING_AVERAGE_DECAY = 0.9999  # Decay rate for the moving average.
NUM_EPOCHS_PER_DECAY = 350.0  # Epochs after which the learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning-rate decay factor.
INITIAL_LEARNING_RATE = 0.1  # Initial learning rate.

TOWER_NAME = 'tower'

# Source of the binary CIFAR-10 archive. Use HTTPS so the downloaded tarball
# cannot be altered in transit before it is extracted.
DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
def distored_inputs():
    """Build the CIFAR-10 training input pipeline using Reader ops.

    Delegates to ``cifar10_input.distored_inputs`` with the
    ``cifar-10-batches-bin`` sub-directory of ``FLAGS.data_dir`` and
    ``FLAGS.batch_size``.

    Returns:
        images: Images, 4-D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
        labels: Labels, 1-D tensor of [batch_size].

    Raises:
        ValueError: If ``FLAGS.data_dir`` is empty.
    """
    root = FLAGS.data_dir
    if not root:
        raise ValueError('Please supply a data_dir')
    batches_dir = os.path.join(root, 'cifar-10-batches-bin')
    return cifar10_input.distored_inputs(batches_dir, FLAGS.batch_size)
def inputs(eval_data):
    """Build the CIFAR-10 evaluation input pipeline using Reader ops.

    Args:
        eval_data: Whether to read the test (evaluation) set; passed through
            unchanged to ``cifar10_input.inputs``.

    Returns:
        images: Images, 4-D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3].
        labels: Labels, 1-D tensor of [batch_size].

    Raises:
        ValueError: If ``FLAGS.data_dir`` is empty.
    """
    if FLAGS.data_dir:
        batches_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
        return cifar10_input.inputs(eval_data, batches_dir, FLAGS.batch_size)
    raise ValueError('Please supply a data_dir')
def _variable_on_cpu(name, shape, initializer):