MNIST_whole_net_XXX

This article presents a three-layer TensorFlow neural network applied to the MNIST dataset. The model consists of an input layer, one hidden layer, and an output layer, and uses strategies such as an exponentially decaying learning rate and moving averages to improve the model's generalization.


Training a Three-Layer TensorFlow Network on MNIST

First, the source code:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# about MNIST dataset
INPUT_NODE = 784
OUTPUT_NODE = 10

# about the Network
LAYER1_NODE = 500
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99

# A helper function:
# given the input and parameters, compute the forward propagation.
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    # If no ExponentialMovingAverage class is supplied, use the current parameter values directly.
    if avg_class is None:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        return tf.matmul(layer1, weights2) + biases2
    # Otherwise fetch the EMA values via avg_class.average() and use those for the forward propagation.
    else:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)




# training processes
def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')

    # hidden layers parameters
    weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE,LAYER1_NODE], stddev = 0.1))
    biases1 = tf.Variable(tf.constant(0.1,shape=[LAYER1_NODE]))

    # output layers parameters
    weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE,OUTPUT_NODE], stddev = 0.1))
    biases2 = tf.Variable(tf.constant(0.1,shape=[OUTPUT_NODE]))

    # forward propagation without the EMA
    y = inference(x, None, weights1, biases1, weights2, biases2)

    # training step counter; marked not trainable
    global_step = tf.Variable(0, trainable=False)

    # Supplying global_step as num_updates lowers the decay early in training,
    # so the shadow variables catch up with the real variables faster.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    # tf.trainable_variables() returns the GraphKeys.TRAINABLE_VARIABLES collection
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    # forward propagation using the EMA shadow variables
    average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)

    # cross-entropy loss; the sparse variant expects integer class labels, hence the argmax
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # l2 regularize loss
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weights1) + regularizer(weights2)
    loss = cross_entropy_mean + regularization

    # exponential learning-rate decay settings:
    # global_step: the current training step
    # mnist.train.num_examples / BATCH_SIZE: steps needed to pass over all training data once
    learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step,
                                               mnist.train.num_examples / BATCH_SIZE,
                                               LEARNING_RATE_DECAY)

    # optimize loss function
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    # update the weights and their exponential moving averages in a single op;
    # equivalent to: train_op = tf.group(train_step, variables_averages_op)
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')

    # check whether each prediction matches the label
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    # accuracy over the evaluated batch
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        train_feed = {x: mnist.train.images, y_: mnist.train.labels}
        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
        test_feed = {x: mnist.test.images, y_: mnist.test.labels }

        # (debug) inspect the feed dictionaries
        print(test_feed, "\n\n\n")
        print(validate_feed, "\n\n\n")
        print(train_feed, "\n\n\n")

        # iteratively train NN
        for i in range(TRAINING_STEPS):
            if i % 100 == 0:
                validate_acc = sess.run(accuracy, feed_dict = validate_feed)
                print("After %d training step(s), validation accuracy using average model is %g " % (i, validate_acc))

            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})

        # after training, evaluate the final accuracy on the test set
        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print("After %d training step(s), test accuracy using average model is %g" % (TRAINING_STEPS, test_acc) )

def main(argv=None):
    mnist = input_data.read_data_sets("/home/user9/DATA/MNIST_manual", one_hot = True)
    train(mnist)

# TensorFlow app entry point; tf.app.run() parses flags and then calls main()
if __name__ == '__main__':
    tf.app.run()

Understanding the Code

  • First, an overall overview.

From these constants you can see that the network has only three layers: an input layer, one hidden layer, and an output layer:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Parameters
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500

The training data is consumed in batches of 100 examples (a sketch of what one batch looks like follows the constant):

BATCH_SIZE = 100 
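
As a minimal sketch (shapes assumed from the standard input_data loader with one_hot=True):

xs, ys = mnist.train.next_batch(BATCH_SIZE)
# xs.shape == (100, 784)  -- a batch of flattened 28x28 images
# ys.shape == (100, 10)   -- the matching one-hot labels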

These specify the parameters of the exponentially decaying learning rate (the base learning rate and its decay rate); a worked sketch of the schedule follows the snippet:

LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
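
Per the TensorFlow docs, tf.train.exponential_decay (in its non-staircase form) computes
decayed_lr = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step / decay_steps).
A plain-Python sketch with the constants above, assuming the standard 55,000-example training split:

decay_steps = 55000 / 100                    # steps per epoch = num_examples / BATCH_SIZE
for step in (0, 550, 5500):
    lr = 0.8 * 0.99 ** (step / decay_steps)
    print(step, lr)                          # 0.8, 0.792, ~0.7235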

The lambda weighting the weight-regularization term in the loss function; it is usually chosen by cross-validation to maximize accuracy (what the L2 term computes is sketched below):

REGULARIZATION_RATE = 0.0001
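
For reference, tf.contrib.layers.l2_regularizer(scale) returns scale * tf.nn.l2_loss(w),
and tf.nn.l2_loss is sum(w ** 2) / 2. A NumPy sketch of the term it adds to the loss:

import numpy as np

w = np.array([0.5, -0.3, 0.2])               # a toy weight vector, for illustration only
reg = 0.0001 * np.sum(w ** 2) / 2            # REGULARIZATION_RATE * l2_loss(w)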

The total number of training steps:

TRAINING_STEPS = 30000

The moving average can be viewed as an averaging operation applied to the weights over the course of training.
The final model produced by a moving-average training run is more robust on test data, i.e., it also performs well on unseen ("wild") data.
The decay rate of the moving-average model (its update rule is sketched after the constant):

MOVING_AVERAGE_DECAY = 0.99
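
Per the TensorFlow docs, each time the apply() op runs, every shadow variable is updated as
shadow = decay * shadow + (1 - decay) * variable, and because global_step is supplied as
num_updates, the effective decay is min(MOVING_AVERAGE_DECAY, (1 + step) / (10 + step)).
A plain-Python sketch of one update:

def ema_update(shadow, variable, step, base_decay=0.99):
    # decay is small early in training, so the shadows catch up with the variables quickly
    decay = min(base_decay, (1.0 + step) / (10.0 + step))
    return decay * shadow + (1.0 - decay) * variable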

The helper function that builds the forward pass (arguments: the input, an optional moving-average class, w1, b1, w2, b2; passing or omitting the class selects whether the moving-average model is used, as the usage sketch below shows):

def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2)
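
The two call modes, exactly as they appear in the full code above:

y = inference(x, None, weights1, biases1, weights2, biases2)                       # raw weights
average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)  # EMA weights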

Then the training function (it takes the dataset as input):

def train(mnist):

The main function, used when the file is run as a standalone script:

def main(argv=None):
    mnist = input_data.read_data_sets( "/home/user9/DATA/MNIST_manual", one_hot = True)
    train(mnist)

When the file is run as a standalone script, main() is executed; when it is imported as a module, main() is not run. A rough sketch of what tf.app.run() does follows the snippet.

if __name__ == '__main__':
    tf.app.run()
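
Roughly what tf.app.run() does in TF 1.x (a simplified sketch, not the actual implementation;
the real version also parses command-line flags before dispatching):

import sys

def run(main=None, argv=None):
    main = main or sys.modules['__main__'].main   # defaults to the module-level main()
    sys.exit(main(argv or sys.argv))
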
  • Next, look inside train(mnist).

What follows is my own simplified version (essentially free-form pseudocode, intended only to aid understanding):

train():
    x = placeholder()
    y_ = placeholder()
    [W1, B1, W2, B2] = random()


# two forward-propagation variants
    # y -- without the moving average
    y = inference(avr_cls = None)


    # average_y -- with the moving average
        # for clarity, global_step is written here as cur_step: the current
        # iteration count, where one step consumes one batch
    cur_step = 0
        # build the moving-average class
    var_aver = new_aver_class(MA_decay, cur_step)
        # the op that actually applies the moving average
    var_aver_op = var_aver -> apply( {W1, B1, W2, B2} )
        # with the class passed in, inference calls "var_aver -> average()",
        # so the EMA version of the forward pass is used
    average_y = inference(avr_cls = var_aver)


# backward propagation
    # mean cross-entropy loss term
    cross_entropy = softmax_cross_entropy(y, y_)
    cross_entropy_mean = mean( cross_entropy )

    # weight-regularization loss term
    regularizer = l2_regularizer( RGL_RATE )  # build the regularizer
    regularization = regularizer( W1, W2 )

    # the total loss function
    loss = cross_entropy_mean + regularization


# optimize using the exponentially decaying learning rate described above
    # trick: exponentially decaying learning rate
    learning_rate = exponential_decay(
        LEARNING_RATE_BASE, 
        global_step, 
        mnist.train.num_examples / BATCH_SIZE, 
        LEARNING_RATE_DECAY)
    # the training op
    train_step = Optimizer( learning_rate ).minimize( loss, global_step )


# bundle the training op and the moving-average op together as a new op: train_op
    train_op <-- [ train_step, var_aver_op ]

# define the accuracy op; its result is the accuracy over one batch
    accuracy = mean( correct_prediction )

# initialize the Session and start training
    with tf.Session() as sess:
        initialize all variables: W1, W2, B1, B2
        validate_feed = ...
        test_feed = ...

        # iterative training phase
        for i in range(TR_STEPS):
            # every 100 iterations, compute the validation accuracy
            if i % 100 == 0:
                validate_acc = sess.run( accuracy, feed_dict = validate_feed )
                print validate_acc

            # draw BATCH_SIZE training examples from the mnist train split
            xs, ys = mnist.train.next_batch( BATCH_SIZE )
            # run train_op
            sess.run( train_op, feed_dict = {x: xs, y_: ys} )

        # after training, compute the accuracy on the test set
        test_acc = sess.run( accuracy, feed_dict = test_feed )
        print test_acc
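
One detail worth noting from the full code: the gradient update and the moving-average update
are bundled into a single op, so one sess.run(train_op) refreshes both. The two forms below
are equivalent:

# as written in the code above:
with tf.control_dependencies([train_step, variables_averages_op]):
    train_op = tf.no_op(name='train')

# equivalent shorthand:
train_op = tf.group(train_step, variables_averages_op)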