import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#载入数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
print(mnist)
#每个批次大小
batch_size = 100
# 计算一共有多少个批次
#n_batch = mnist.train.num//batch_size
#定义两个placeholder
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
#单层神经网络
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x, W) + b)
#二次代价函数
loss = tf.reduce_mean(tf.square(y - prediction))
#梯度下降法,优化器
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
#初始化变量
init = tf.global_variables_initializer()
#结果存放到一个布尔类型列表中
correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(prediction, 1))#argmax返回一维张量中最大的值在哪个位置
#求准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))#布尔型转化为floa32,true=1.0,false=0.0,再求平均值
with tf.Session() as sess:
sess.run(init)
for i in range(50): #训练次数
for batch in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
print(i, 'Iter' + ',Testing Accuracy' + str(acc))
手写数字识别
最新推荐文章于 2025-07-31 08:43:25 发布