本文内容参考于李嘉璇老师的tensorflow技术解析与实战
可能是因为版本的原因,按照原书的代码并不能很好的实现对mnist数据集的识别精度,此外存储的tensorflow模型加载过程也出现了一些问题,因此在原书代码基础上对代码进行了部分修改。
1、简介
为什么要进行模型存储和加载?
对于训练好的网络模型,如何将其应用在预测数据上?这就使得我们需要把训练好的模型提取出来,方便后续的使用,即模型的存储;对于存储起来的模型如何进行使用?这就牵扯到了模型的加载。
tensorflow的API提供了两种存储和加载模型的方式:
(1)生成检查点文件,扩展名一般是.ckpt,通过在tf.train.Saver对象上调用Saver.save()生成。特点:包含权重和其他在程序中定义的变量,不包含图结构,所以如果需要在另一个程序中使用,需要重新创建图结构,并告诉tensorflow如何处理这些权重;
(2)生成图协议文件,这是一个二进制文件,扩展名为.pb,用tf.train.write_graph()保存,只包含图形结构,不含权重,然后使用tf.import_graph_def()来加载图形。
2、模型存储
修改后的可执行代码为:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
mnist = input_data.read_data_sets("/path/to/mnist/dataset/ ",one_hot=True)
sess = tf.InteractiveSession()
X = tf.placeholder(tf.float32,[None,784])
Y = tf.placeholder(tf.float32,[None,10])
#定义权重初始化函数以及权重向量
w_h = tf.Variable(tf.random_normal([784,625], stddev=0.01),name = 'w_1')
w_h2 = tf.Variable(tf.random_normal([625,625], stddev=0.01),name = 'w_2')
w_o = tf.Variable(tf.random_normal([625,10], stddev=0.01),name = 'w_0')
variables_dict = {'w_1':w_h, 'w_2':w_h2, 'w_o':w_o}
#定义模型
def model(X, w_h , w_h2 , w_o, p_keep_input,p_keep_hidden):
X = tf.nn.dropout(X, p_keep_input)
h = tf.nn.relu(tf.matmul(X, w_h))
h = tf.nn.dropout(h, p_keep_hidden)
h2 = tf.nn.relu(tf.matmul(h,w_h2))
h2 = tf.nn.dropout(h2, p_keep_hidden)
return tf.nn.softmax(tf.matmul(h2 ,w_o))
p_keep_input = tf.placeholder(tf.float32)
p_keep_hidden = tf.placeholder(tf.float32)
py_x = model(X, w_h,w_h2,w_o,p_keep_input,p_keep_hidden)
#损失函数定义
cost = tf.reduce_mean(-tf.reduce_mean(Y*tf.log(py_x),reduction_indices=[1]))
train_op = tf.train.RMSPropOptimizer(0.001,0.9).minimize(cost)
predict_op = tf.equal(tf.argmax(py_x,1),tf.argmax(Y,1))
accuracy = tf.reduce_mean(tf.cast(predict_op,tf.float32))
#模型存储相关
ckpt_dir = "./ckpt_dir" #模型存储路径设置
if not os.path.exists(ckpt_dir):
os.makedirs(ckpt_dir)
global_step = tf.Variable(0,name='global_step',trainable = False)#定义计数器,为训练轮数计数
saver = tf.train.Saver(variables_dict)#调用模型存储API,并指定一个变量列表或者字典,传给tf.train.Saver()
non_storable_variable = tf.Variable(777)
#模型训练并存储
with tf.Session() as sess:
tf.global_variables_initializer().run()
start = global_step.eval()
for i in range(start, 1000):
batch = mnist.train.next_batch(100)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={X: batch[0],Y:batch[1],p_keep_hidden:0.8,p_keep_input:0.5}) #keep_prob训练时通常小于1,测试时为1
print("step %d,training accuracy %g"%(i,train_accuracy))
sess.run(train_op,feed_dict={X: batch[0],Y:batch[1],p_keep_hidden:0.8,p_keep_input:0.5})
#for start, end in zip(range(0,len(trX),128),range(128,len(trX)+1),128):
#sess.run(train_op,feed_dict={X:trX[start:end],Y:trY[start:end],p_keep_hidden:0.8,p_keep_input:0.5})
global_step.assign(i).eval()
saver.save(sess, ckpt_dir + "/model.ckpt",global_step=global_step)
trainaccuracy = sess.run(accuracy,feed_dict={X:mnist.test.images,Y:mnist.test.labels,p_keep_hidden:1.0,p_keep_input:1.0})
print("test accuracy %g"%trainaccuracy)
原文代码中直接调用的是saver = tf.train.Saver(),而不是类似于本代码中指定一个变量列表或者字典variable_dict,直接传给tf.train.Saver(variable_dict),原因是,我采用原文方式进行测试时,虽然模型能够顺利的保存,但是在加载时却出现可类似于“Key Variable_xxx not found in checkpoint”的错误,所以参照AAAAAAAAAAAA,我对模型中关键的权重(因为本模型中未加偏置向量)建立了变量列表:
w_h = tf.Variable(tf.random_normal([784,625], stddev=0.01),name = 'w_1')
w_h2 = tf.Variable(tf.random_normal([625,625], stddev=0.01),name = 'w_2')
w_o = tf.Variable(tf.random_normal([625,10], stddev=0.01),name = 'w_0')
variables_dict = {'w_1':w_h, 'w_2':w_h2, 'w_o':w_o}
训练结果:
在ckpt_dir文件下最终会存在有16个文件,其中5个是zuimodel.ckpt-(n).data-00000-of-00001文件,是训练过程中保存的模型,5个model.ckpt-{n}.meta文件,是训练过程中保存的元数据,5个model.ckpt-{n}-index文件,{n}表示迭代次数,以及一个检查点文本文件,里面保存着当前模型和最近的五个模型。
通过观察这几个文件可以发现,它们最后5轮训练的参数模型,其因为tensorflow默认只保存最近五个模型和元数据,删除前面没用的模型和元数据。这么做的原因主要是两个:首先是防止在模型训练过程由于某些原因导致脚本停止运行,对于大数据集长时间的训练,训练了几周两次突然电脑重启,这时重新训练岂不是太浪费时间,所以模型存储便可避免这种意外。此外,其会在每个固定的轮数在检查点保存一个模型,方便随时将模型提取踹进行预测和评估网络。
3、模型加载
如果有已经训练好的模型,只需利用saver.restory进行模型加载:
with tf.Session() as sess:
tf.global_variables_initializer().run()
#ckpt = tf.train.get_checkpoint_state(ckpt_dir)
saver.restore(sess,ckpt_dir)
trainaccuracy = sess.run(accuracy,feed_dict={X:mnist.test.images,Y:mnist.test.labels,p_keep_hidden:1.0,p_keep_input:1.0})
print("test accuracy %g"%trainaccuracy)#feed_dict给placeholder创建的tensor赋值
完整代码:
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
import os
mnist = input_data.read_data_sets("path/to/mnist/dataset/",one_hot=True)
sess = tf.InteractiveSession()
X = tf.placeholder(tf.float32,[None,784])
Y = tf.placeholder(tf.float32,[None,10])
#定义权重初始化函数以及权重向量
w_h = tf.Variable(tf.random_normal([784,625], stddev=0.01),name = 'w_1')
w_h2 = tf.Variable(tf.random_normal([625,625], stddev=0.01),name = 'w_2')
w_o = tf.Variable(tf.random_normal([625,10], stddev=0.01),name = 'w_0')
variables_dict = {'w_1':w_h, 'w_2':w_h2, 'w_o':w_o}
def model(X, w_h , w_h2 , w_o, p_keep_input,p_keep_hidden):
X = tf.nn.dropout(X, p_keep_input)
h = tf.nn.relu(tf.matmul(X, w_h))
h = tf.nn.dropout(h, p_keep_hidden)
h2 = tf.nn.relu(tf.matmul(h,w_h2))
h2 = tf.nn.dropout(h2, p_keep_hidden)
return tf.nn.softmax(tf.matmul(h2 ,w_o))
p_keep_input = tf.placeholder(tf.float32)
p_keep_hidden = tf.placeholder(tf.float32)
py_x = model(X, w_h,w_h2,w_o,p_keep_input,p_keep_hidden)
#损失函数定义
cost = tf.reduce_mean(-tf.reduce_mean(Y*tf.log(py_x),reduction_indices=[1]))
train_op = tf.train.AdamOptimizer(1e-4).minimize(cost)
predict_op = tf.equal(tf.argmax(py_x,1),tf.argmax(Y,1))
accuracy = tf.reduce_mean(tf.cast(predict_op,tf.float32))
ckpt_dir = "./ckpt_dir/model.ckpt-999"
global_step = tf.Variable(0,name='global_step',trainable = False)
saver = tf.train.Saver(variables_dict)
non_storable_variable = tf.Variable(777)
#模型加载及测试
with tf.Session() as sess:
tf.global_variables_initializer().run()
#ckpt = tf.train.get_checkpoint_state(ckpt_dir)
saver.restore(sess,ckpt_dir)
trainaccuracy = sess.run(accuracy,feed_dict={X:mnist.test.images,Y:mnist.test.labels,p_keep_hidden:1.0,p_keep_input:1.0})
print("test accuracy %g"%trainaccuracy)
输出:
test accuracy 0.9589