MNIST训练TensorFlow三层网络
首先是源代码:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# --- MNIST dataset geometry ---
INPUT_NODE = 28 * 28  # one input unit per pixel of a 28x28 image (= 784)
OUTPUT_NODE = 10  # one output unit per digit class

# --- network architecture ---
LAYER1_NODE = 500  # width of the single hidden layer

# --- optimisation schedule ---
BATCH_SIZE = 100  # examples consumed per training step
TRAINING_STEPS = 30000  # total number of training steps
LEARNING_RATE_BASE = 0.8  # initial learning rate for exponential decay
LEARNING_RATE_DECAY = 0.99  # decay factor of the learning rate

# --- regularisation / averaging ---
REGULARIZATION_RATE = 0.0001  # coefficient of the L2 weight penalty in the loss
MOVING_AVERAGE_DECAY = 0.99  # decay of the exponential moving average
# Helper: given an input tensor and the network parameters, build the
# forward-propagation graph.
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    """Build the forward pass of the one-hidden-layer network.

    Args:
        input_tensor: tensor that is matrix-multiplied with ``weights1``.
        avg_class: ``None`` to use the current parameter values, or an object
            exposing ``average(var)`` (``train`` passes a
            ``tf.train.ExponentialMovingAverage``) to use the averaged values.
        weights1: hidden-layer weight variable.
        biases1: hidden-layer bias variable.
        weights2: output-layer weight variable.
        biases2: output-layer bias variable.

    Returns:
        The output-layer result ``layer1 * weights2 + biases2``. No softmax is
        applied here; ``layer1`` is the ReLU-activated hidden layer.
    """
    # Without an averaging object, use the current parameter values directly.
    # Compare with ``is``: None is a singleton and ``==`` may be overloaded.
    if avg_class is None:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        return tf.matmul(layer1, weights2) + biases2

    # Otherwise look up the moving average of every parameter via
    # avg_class.average() and run the same forward pass on those values.
    layer1 = tf.nn.relu(
        tf.matmul(input_tensor, avg_class.average(weights1))
        + avg_class.average(biases1)
    )
    return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)
# training process
def train(mnist):
    """Build the network graph, train it on MNIST and print the accuracies.

    Args:
        mnist: dataset object exposing ``train``, ``validation`` and ``test``
            splits, each with ``images`` and ``labels``; ``train`` must also
            provide ``num_examples`` and ``next_batch(batch_size)``. Labels
            are fed to a ``[None, OUTPUT_NODE]`` placeholder and reduced with
            ``argmax``, i.e. they are expected one-hot encoded (``main`` loads
            the data with ``one_hot=True``).

    Returns:
        None. Validation accuracy is printed every 100 steps and the test
        accuracy once after ``TRAINING_STEPS`` steps.
    """
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    # Label placeholder. It gets its own descriptive name so it cannot be
    # confused with the auto-generated names of the add ops in the graph.
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    # hidden-layer parameters
    weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
    # output-layer parameters
    weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))

    # forward propagation with the current (non-averaged) parameters;
    # this output is the one used by the loss
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # Step counter; not trainable, so it is excluded from the moving averages
    # and from the gradient update. minimize() below increments it.
    global_step = tf.Variable(0, trainable=False)

    # Passing global_step as num_updates lets the average use a smaller
    # effective decay during the first steps, so it tracks early training
    # faster (see the ExponentialMovingAverage documentation).
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    # tf.trainable_variables() returns the GraphKeys.TRAINABLE_VARIABLES collection
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # forward propagation with the averaged parameters; used for evaluation only
    average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)

    # Cross-entropy loss. The sparse variant takes class indices, hence the
    # argmax over the one-hot labels. Arguments are passed by keyword because
    # TensorFlow >= 1.0 rejects positional arguments for this function.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # L2 regularisation on the weights only (biases are not penalised)
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)
    loss = cross_entropy_mean + regularization

    # Exponentially decaying learning rate:
    #   global_step                           - current training step
    #   mnist.train.num_examples / BATCH_SIZE - steps needed for one pass
    #                                           over the training data
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY)

    # one gradient-descent update of the loss; also increments global_step
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Update the weights and their moving averages in a single op.
    # Equivalent to: train_op = tf.group(train_step, variables_averages_op)
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')

    # per-example correctness of the averaged model's prediction
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    # fraction of correct predictions over whatever examples are fed
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        tf.initialize_all_variables().run()

        # Labels must be fed to the placeholder y_, not to the logits tensor
        # y; otherwise running `accuracy` fails because y_ has no value.
        train_feed = {x: mnist.train.images, y_: mnist.train.labels}
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        # Debug dump of the three feed dicts. The single-string form prints
        # the same bytes under Python 2 and Python 3.
        print("%s \n\n\n" % (test_feed,))
        print("%s \n\n\n" % (validate_feed,))
        print("%s \n\n\n" % (train_feed,))

        # iteratively train the network
        for i in range(TRAINING_STEPS):
            # report validation accuracy every 100 steps
            if i % 100 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                print("After %d training step(s), validation accuracy using average model is %g " % (i, validate_acc))
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        # after training, measure the final accuracy on the test set
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training step(s), test accuracy using average model is %g" % (TRAINING_STEPS, test_acc))
def main(argv=None):
    """Load the MNIST data with one-hot labels and hand it to train().

    ``argv`` is accepted so the function can be called with an argument
    list; it is not used.
    """
    train(input_data.read_data_sets("/home/user9/DATA/MNIST_manual", one_hot=True))
# Script entry point: tf.app.run() is TensorFlow's application launcher and
# calls the main() function of this file.
if __name__ == "__main__":
    tf.app.run()
代码理解
- 首先是整体概览
从这里可以看出网络只有三层:输入层、隐层、输出层:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Parameters
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500
把数据集按照每100个分成一组:
BATCH_SIZE = 100
这里指定了指数衰减学习率的参数(基础学习率、学习率的衰减率):
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
损失函数中权重正则项所占的lambda,这个值一般通过cross_validation来选取,使得正确率达到最优:
REGULARIZATION_RATE = 0.0001
设定总训练次数:
TRAINING_STEPS = 30000
滑动平均——可以看作训练过程中,对于权重的平均值池化操作。
使用滑动平均模型训练得到的最终模型,在测试数据上具有更好的健壮性——即在没见过的(wild)数据上也可以取得好的效果。
滑动平均模型的衰减率:
MOVING_AVERAGE_DECAY = 0.99
构建前向网络的辅助函数(参数:输入,滑动平均类——可选,w1,b1,w2,b2. 可以支持选择是否使用滑动平均模型):
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
然后是训练函数(需要输入数据库):
def train(mnist):
主函数,用于作为单文件运行的时候:
def main(argv=None):
mnist = input_data.read_data_sets( "/home/user9/DATA/MNIST_manual", one_hot = True)
train(mnist)
如果本文件作为脚本直接运行,则通过 tf.app.run() 调用 main();如果是被其他文件 import,则不执行 main(),只作为提供函数的模块使用。
if __name__ == '__main__':
tf.app.run()
- 然后看train(mnist)的内部。
接下来是我自己总结的简化版本(相当于无语法规范的伪代码,只是为了方便理解。):
train():
x = placeholder()
y_ = placeholder()
[W1, B1, W2, B2] = random()
# 两种前向传播方式
# y ——不采用滑动平均
y = inference(avr_cls = None)
# average_y ——采用滑动平均
#为方便理解,这里把global_step写成cur_step,其实就是当前迭代的步数:一步输入一个batch
cur_step = 0
# 建立 滑动平均函数(类)
var_aver = new_aver_class(MA_decay, cur_step)
# 实现 滑动平均 的操作
var_aver_op = var_aver -> apply( {W1, B1, W2, B2} )
# 输入了 滑动平均类,调用 "var_aver -> average()",从而采用 EMA 版本的前向传播
average_y = inference(avr_cls = var_aver)
# 后向传播
# 平均交叉熵 loss项
cross_entropy = softmax_cross_entropy( y, argmax(y_, 1) )
cross_entropy_mean = mean( cross_entropy )
# 权重 loss项
regularizer = l2_regularizer( RGL_RATE )# 建立 正则函数
regularization = regularizer( W1, W2 )
# 总的 loss function
loss = cross_entropy_mean + regularization
# 采用上述 指数下降的学习率 优化
# 技巧:指数下降的学习率
learning_rate = exponential_decay(
LEARNING_RATE_BASE,
global_step,
mnist.train.num_examples / BATCH_SIZE,
LEARNING_RATE_DECAY)
# 训练 操作
train_step = Optimizer( learning_rate ).minimize( loss, global_step )
# 训练 和 滑动平均 操作 放在一块,定义成新的操作:train_op
train_op <--[ train_step, var_aver_op]
# 定义 accuracy 运算,得到的是一个 batch上 的正确率
correct_prediction = equal( argmax(average_y, 1), argmax(y_, 1) )
accuracy = mean( correct_prediction )
# 初始化 Session,开始训练
with tf.Session() as sess:
initialize all variables: W1,W2,B1,B2
validate_feed = ...
test_feed = ...
# 迭代训练阶段
for i in range(TR_STEPS):
# 每100次迭代 计算一次 accuracy(对应源代码中的 if i % 100 == 0)
validate_acc = sess.run( accuracy ,feed_dict = validate_feed )
print validate_acc
# 从 mnist 的 train 数据集上 选取 batch_size 个训练数据
xs, ys = mnist.train.next_batch( BATCH_SIZE )
# 开始运行 train_op
sess.run( train_op , feed_dict = {x: xs, y_: ys})
# 训练完成后,在测试集上计算 正确率
test_acc = sess.run( accuracy , feed_dict = test_feed )
print test_acc