基于tensorflow实现mnist手写数字的识别
在优快云上看了不少的博客文章,学到了很多。今后会把自己的一些浅显理解写出来,算是对自己学习过程的记录吧。
本文参考自:[地址](https://blog.youkuaiyun.com/jerr__y/article/details/61195257)
准备阶段:导入相关模块
from __future__ import print_function
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.ops import rnn
# 导入数据集,并打印数据形式,及tensorflow版本
mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)
print(mnist.train.images.shape)
print(tf.__version__)
输出结果
# Extracting ./MNIST_data/train-images-idx3-ubyte.gz
# Extracting ./MNIST_data/train-labels-idx1-ubyte.gz
# Extracting ./MNIST_data/t10k-images-idx3-ubyte.gz
# Extracting ./MNIST_data/t10k-labels-idx1-ubyte.gz
# (55000, 784)
# 0.8.0
我所用的是0.8.0的tensorflow 不同版本一些调用语句会有所不同,自行查阅资料调试即可。
参数设置
# 设置学习率,为后面训练阶段做准备
lr = 1e-3
# 设置batch_size,为便于调整,这里先采用占位符,数据类型为tf.int32
batch_size = tf.placeholder(tf.int32)
# 每次输入一个28维的向量
input_size = 28
# 时序长度也设为28 ,输入28个向量后做一次预测,恰好是一张手写数字图
timestep_size=28
# 每个隐层节点数,用来记忆和储存状态的节点数
hidden_size=256
# 两层LSTM layer
layer_num=2
# 类别数 0-9 共10类
class_num=10
# 输入 输出和keep_prob设置
_X=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,class_num])
keep_prob=tf