机器学习笔记7:TensorFlow进阶之利用CNN训练MNIST
本文的理论基础部分以及参考代码源于TensorFlow中文社区以及aliceyangxi1987的博客。
代码分析及调试
在aliceyangxi1987的博客中,基本的代码思路与中文社区中的思路基本一致,不同的地方在于,博客中的代码将准确率计算的步骤进行封装成一个函数也就是compute_accuracy()函数。整体代码如下:
# coding=utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# number 1 to 10 data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
def compute_accuracy(v_xs, v_ys):
global prediction
y_pre = sess.run(prediction, feed_dict={xs: v_xs, keep_prob: 1})
correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(v_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys, keep_prob: 1})
return result
# 产生随机变量,符合 normal 分布
# 传递 shape 就可以返回weight和bias的变量
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
# 定义2维的 convolutional 图层
def conv2d(x, W):
# stride [1, x_movement, y_movement, 1]
# Must have strides[0] = strides[3] = 1
# strides 就是跨多大步抽取信息
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
# 定义 pooling 图层
def max_pool_2x2(x):
# stride [1, x_movement, y_movement, 1]
# 用pooling对付跨步大丢失信息问题
return tf.nn.max_pool(x, ksize=[1,<