首先读取数据,数据源是mnist库,可以通过input_data中read_data函数直接读取数据,数据图像为28*28。
#导入库
import tensorflow as tf
#下载数据对应的库
import input_data
import numpy as np
import matplotlib.pyplot as plt
print ("Packages imported")
#导入mnist数据
mnist = input_data.read_data_sets("data/", one_hot=True)#one_hot=True 表示 数据的标签是one_hot编码的,即数据标签为1*10的数组
#读取训练数据,训练标签,测试数据,测试标签
trainimgs, trainlabels, testimgs, testlabels \
= mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
#获取训练数据个数,测试集数据个数,图像维度和类别数
ntrain, ntest, dim, nclasses \
= trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")
读取数据后设置参数。本次使用LSTM作为训练模型,因此需要搭建LSTM,因图像为28*28,所以将每一行图像作为一次输入,这样每一次训练,LSTM需要运算28次,设置隐层为128,所以从输入到隐层的全连接参数为28*128个,经过运算后输出与隐层全连接参数为128*10,每运行一次的输出为1*10.
LSTM结构如图:
搭建了28个lstm,每一个lstm公用参数,因此也可以看做搭建了一个lstm循环了28次,只有最终的结果有作用。