MNISTone----用的softmax

本文介绍了一个基于TensorFlow的手写数字识别系统实现过程。该系统利用MNIST数据集进行训练,采用简单的线性模型结合softmax函数进行分类,并通过梯度下降法优化交叉熵损失函数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

# -*- coding: UTF-8 -*- 
'''
Created on 2017年12月8日
'''
#以下两句用于下载数据
import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf  
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)     #下载并加载mnist数据

#输入输出占位符
x = tf.placeholder(tf.float32,[None, 784]) #图像输入向量,占位符,每一个sample都是784维,none表示可以有任意个sample 
y_ = tf.placeholder("float", [None,10])  #占位符,每一个sample都是10维,因为是one_hot
#参数
W = tf.Variable(tf.zeros([784,10]))  #权重,初始化值为全零,变量
b = tf.Variable(tf.zeros([10]))  #偏置,初始化值为全零,变量
#进行模型建立及计算,y是预测,y_ 是实际  
y = tf.nn.softmax(tf.matmul(x,W) + b)    

#计算交叉熵  
cross_entropy = -tf.reduce_sum(y_*tf.log(y+1e-10))  
tf.scalar_summary('cross_entropy',cross_entropy) 
#接下来使用BP算法来进行微调,以0.01的学习速率,使用的是简单的梯度下降算法----记住,这是一个优化算子
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  


#上面设置好了模型,添加初始化创建变量的操作  
init = tf.initialize_all_variables()  
#启动创建的模型,并初始化变量  
sess = tf.Session()  
sess.run(init)    #init也是操作
merged = tf.merge_all_summaries() #collect the tf.xxxxx_summary  
writer = tf.train.SummaryWriter('/home/tensorBoardLog/MNISTone',sess.graph)   
#开始训练模型,循环训练1000次  
for i in range(1000):  
    #随机抓取训练数据中的100个批处理数据点  
    batch_xs, batch_ys = mnist.train.next_batch(100)      #mnist.train #这里边有疑问:next_batch is a method of the DataSet class
    #https://stackoverflow.com/questions/40368697/where-does-next-batch-in-the-tensorflow-tutorial-batch-xs-batch-ys-mnist-trai
    #可以在github上看到
    summary,loss, _= sess.run([merged, cross_entropy, train_step], feed_dict={x:batch_xs,y_:batch_ys})  #train_step是一个操作,step表示每一步
    #注意是操作;给模型必要的输入,以及必要的操作指示
    writer.add_summary(summary, i)
    print('range: %04d, loss = %-9f' % (i+1, loss))

''''' 进行模型评估 '''  
#判断预测标签和实际标签是否匹配  
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))   #1表示在1轴上,0轴表示的是样本index
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))  #tf.cast是类型转换函数
#计算所学习到的模型在测试数据集上面的正确率  
print( sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels}) )  #mnist.test,注意了,x,y只是占位符,train test都可以用
#accuracy,注意了,这个时候w,b不会再变了,所以x进去自然会有一个y出来;feed_dict表示输入字典

ss


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值