"""神经网络领域的 "Hello World" 代码，用于了解神经网络的基本流程：

1. 输入数据；
2. 用带参数的模型拟合曲线：基础形式为线性模型 y = Σᵢ aᵢ·xᵢ + b
   （本脚本在线性输出之后再接 softmax，用于 10 分类）；
3. 训练：通过反向传播调整参数 a、b 的值，修正拟合曲线。
"""
# 3-2 MNIST 数据集分类（简单版本）
import tensorflow as tf;
import numpy as np;
from tensorflow.examples.tutorials.mnist import input_data
# --- Raw data -------------------------------------------------------------
# Read the MNIST train/test splits from the "MNIST_data" directory with
# one-hot encoded labels.
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Inputs: each row of x is one example with 784 float features; each row of
# y is the matching one-hot label over 10 classes.
x = tf.placeholder(tf.float32, shape=[None, 784])
y = tf.placeholder(tf.float32, shape=[None, 10])

# --- Hyperparameters ------------------------------------------------------
# NOTE(review): 10000 is an unusually large mini-batch for SGD; with
# n_batch = train_size // batch_size this gives only a few update steps per
# epoch. Confirm the intent (tutorial versions commonly use 100).
batch_size = 10000
n_epochs = 1221
learning_rate = 0.2
# Number of full batches per epoch (floor division drops any remainder).
n_batch = mnist.train.num_examples // batch_size

# --- Computation graph ----------------------------------------------------
# Single-layer softmax regression: logits = x·w + b.
w = tf.Variable(tf.truncated_normal(shape=[784, 10], stddev=0.1))
b = tf.Variable(tf.constant(0.1, shape=[10]))
logits = tf.matmul(x, w) + b
prediction = tf.nn.softmax(logits)

# --- Loss and optimizer ---------------------------------------------------
# softmax_cross_entropy_with_logits applies softmax internally, so it must be
# fed the raw (unscaled) logits. Passing `prediction` would apply softmax
# twice and yield a wrong loss and wrong gradients.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y)
)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Accuracy: fraction of examples whose highest-scoring class matches the
# one-hot label.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# --- Training -------------------------------------------------------------
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(n_epochs):
        for _ in range(n_batch):
            x_batch, y_batch = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: x_batch, y: y_batch})
        # Evaluate on the full test split once per epoch.
        acc = sess.run(
            accuracy,
            feed_dict={x: mnist.test.images, y: mnist.test.labels},
        )
        print("epoch:{0},acc:{1}".format(epoch, acc))