from __future__ import print_function#即使是在python2版本也要像在Python3中使用print函数 from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("/tmp/data/",one_hot=True)#onehot对标签的标注,非onehot是1,2,3.onehot就是只有一个1其余全是0 import tensorflow as tf #超参数(学习率,batch的大小,训练的轮数,多少轮展示一下loss) learning_rate = 0.1 num_step = 500 batch_size = 128 display_step =100 #网络参数(有多少层网络,每层有多少个神经元,整个网络的输入是多少维度的,输出是多少维度的) n_hidden_1 = 256 n_hidden_2 = 256 num_input = 784#(28*28) num_class = 10 #图的输入 X = tf.placeholder("float",[None,num_input]) Y = tf.placeholder("float",[None,num_class]) #网络的权重和偏向,如果是两个隐层的话需要定义三个权重,包括输出层 weights={ 'h1':tf.Variable(tf.random_normal([num_input,n_hidden_1])), 'h2':tf.Variable(tf.random_normal([n_hidden_1,n_hidden_2])), 'out':tf.Variable(tf.random_normal([n_hidden_2,num_class])) } biase = { 'b1':tf.Variable(tf.random_normal([n_hidden_1])), 'b2':tf.Variable(tf.random_normal([n_hidden_2])), 'out':tf.Variable(tf.random_normal([num_class])) } #定义网络结构 def neural_net(x): layer_1 = tf.add(tf.matmul(x,weights['h1']),biase['b1']) layer_2 = tf.add(tf.matmul(layer_1,weights['h2']),biase['b2']) out_layer = tf.add(tf.matmul(layer_2,weights['out']),biase['out']) return out_layer #模型输出处理 logits = neural_net(X) prediction = tf.nn.softmax(logits) #定义损失和优化器 loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=Y)) optimizer = tf.train.AdamOptimizer(learning_rate = learning_rate) train_op = optimizer.minimize(loss_op) #评估模型准确率 correct_pred = tf.equal(tf.argmax(prediction,1),tf.argmax(Y,1)) accuracy = tf.reduce_mean(tf.cast(correct_pred,tf.float32)) #初始化变量 init = tf.global_variables_initializer() #开始训练 with tf.Session() as sess: sess.run(init) for step in range(1,num_step+1): batch_x,batch_y = mnist.train.next_batch(batch_size) if step % display_step == 0 or step == 1: loss,acc = sess.run([loss_op,accuracy],feed_dict={X:batch_x,Y:batch_y}) print("step:{},loss:{},acc:{}".format(step,loss,acc)) print("优化完成!") #训练完模型后,开始测试 print("testing Accuracy:",sess.run(accuracy,feed_dict={X:mnist.test.images,Y:mnist.test.labels}))
tensorflow实现手写数字识别(MLP)
最新推荐文章于 2024-12-11 07:00:00 发布