Tensorflow MNIST解析

本文介绍了如何使用TensorFlow处理MNIST数据集,从基础神经网络开始,逐步优化,包括设置指数衰减学习率。通过全局步骤变量和指数衰减策略调整学习率,以改善模型训练效果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

数据的导入

函数read_data_sets(“数据地址”)
导入后的数据函数自动分成了train,validation,test三个对象

#可以通过下面的方式查看大小
print (data.train.num_examples)
print (data.validation.num_examples)
print (data.test.num_examples)

导入后的3个对象可以使用API,next_batch(size)获得下一组数据, 其中size是下一组的大小。

x_cur,y_cur = data.train.next_batch(100)

利用神经元网络对MNIST进行识别

一个神经元网络,由于其复杂程度,存在了好多优化的可能性,因此,在下面的代码中,我也会首先构造一个最基本的神经元网络,然后,从多个方面对其进行优化。

  • 损失函数的定义
  • 隐藏层数和隐藏层节点个数
  • 学习率
  • 优化算法
  • 正则化

最基础版本

import input_data
import tensorflow as tf
batch_size = 100
hidden1_nodes = 200
# 输入节点
x = tf.placeholder(tf.float32,shape=(None,784))
y = tf.placeholder(tf.float32,shape=(None,10))
#权值&隐藏层
w1 = tf.Variable(tf.random_normal([784,hidden1_nodes],stddev=0.1))
w2 = tf.Variable(tf.random_normal([hidden1_nodes,10],stddev=0.1))
hidden = tf.nn.relu(tf.matmul(x,w1)+b1)
b1 = tf.Variable(tf.random_normal([hidden1_nodes],stddev=0.1))
b2 = tf.Variable(tf.random_normal([10],stddev=0.1))
y_predict = tf.nn.relu(tf.matmul(hidden,w2)+b2)
#损失函数
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_predict))
#训练函数
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#测试函数
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

#执行
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(5000):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        if i%1000==0:
            print 'Phase'+str(i/10000+1)+':',sess.run(accuracy, feed_dict={x: mnist.test.images, y:mnist.test.labels})

运行结果如下:
这里写图片描述

设置指数衰减学习率

讲解两个概念,第一个是全局步骤global_step。这是一个变量,用来跟踪全局的步骤,不是训练的变量,只是一个累加器。
因此,我们可以用下面的代码来跟踪全局的步骤:

global_step = tf.Variable(0, name='global_step', trainable=False)
train_step = optimizer.minimize(loss, global_step=global_step)

第二个概念是具有指数衰减性质的学习率,其目的是防止学习率一直保持不变。因为开始的时候,由于参数初始化,因此学习率可以大一些,但到后面的训练阶段,参数已经接近最优解,可以将学习率减小,防止参数抖动。

learning_rate = tf.train.exponential_decay(0.1, global_step, 100, 0.96, staircase=True)

上面的代码中,0.1表示初始学习率,global_step表示当前的全局步骤,100表示衰减速度,0.96表示衰减速率,衰减速度表示每100轮训练后学习率乘以0.96。
完整的带指数衰减学习率的MNIST分类代码

import input_data
import tensorflow as tf
batch_size = 100
hidden1_nodes = 200
# 输入节点
x = tf.placeholder(tf.float32,shape=(None,784))
y = tf.placeholder(tf.float32,shape=(None,10))
#权值&隐藏层
w1 = tf.Variable(tf.random_normal([784,hidden1_nodes],stddev=0.1))
w2 = tf.Variable(tf.random_normal([hidden1_nodes,10],stddev=0.1))
hidden = tf.nn.relu(tf.matmul(x,w1)+b1)
b1 = tf.Variable(tf.random_normal([hidden1_nodes],stddev=0.1))
b2 = tf.Variable(tf.random_normal([10],stddev=0.1))
y_predict = tf.nn.relu(tf.matmul(hidden,w2)+b2)
#损失函数
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_predict))
#设置学习率
global_step = tf.Variable(0, name='global_step', trainable=False)    //全局的一个计数器
learning_rate = tf.train.exponential_dacay(0.1,global_step,0.96,staircase=True)
#训练函数
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy,global_step=global_step)
#测试函数
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_predict, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

#执行
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(5000):
        print 'Learning_rate:',sess.run(learning_rate),'Global_step:',sess.run(global_step)
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
        if i%1000==0:
            print 'Phase'+str(i/10000+1)+':',sess.run(accuracy, feed_dict={x: mnist.test.images, y:mnist.test.labels})
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值