众所周知,Tensorflow框架在机器学习、深度学习领域有着广泛的应用。博主最近就Tensorflow框架进行了简单的学习,当我完成简单的回归模型以及CNN的训练后,一个想法油然而生,怎样将训练好的模存保存呢?保存之后又该怎么调用呢?----------于是,我查阅了很多相关资料,经过很多的测试,终于…成功了!现将如何保存和调用训练好的模型总结如下(提一下,本文以CNN为例,数据集是众所周知的MNIST):
1.模型的保存
模型的保存通常有两种方法,一种是Tensorflow老版本中的t.trans.saver,另一种是新版本中的tf.saved_model.builder,博主使用的是后者,所以对于第一种方法也就一带而过。
1.1 tf.train.saver这种方法是使用较多的,但是呢,这种方法更多地是用来保存变量(变量)的,所以博主也就没有了解太多,下文po一段典型代码:
import tensorflow as tf
...
#在这里构建网络
...
#开始保存模型
与tf.Session()作为sess:
sess.run(tf.global_variables_initializer())#一定要先初始化整个流
#在这里训练网络
...
#保存参数
saver = tf.train.Saver()
saver.save(sess,PATH)#PATH就是要保存的路径
1.2 tf.saved_model.builder
下文只放出保存模型的关键代码,至于训练的过程,会在后续一并放出,整体分析。
将tensorflow import 为tf
...
#构建网络
...
用tf.Session()作为sess:
sess.run(tf.global_variables_initializer())#一定要先初始化整个流
#在这里训练网络
...
#保存参数
builder = tf.saved_model.builder.SaveModelBuilder(PATH)#PATH是保存路径
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.TRAINING])#保存整张网络及其变量,这种方法是可以保存多张网络的,在此不作介绍,可自行了解
builder.save()#完成保存
2.模型的调用
以上就简单介绍了两种保存模型的方法,接下来介绍一下读取模型的方法,只讲基于tf.saved_model.builder方法的读取模型了>划重点!完整调用并使用一个模型,分为加载网络和加载关键节点两个部分。幸运的是,使用tf.saved_model.builder保存和调用模型,这一切都变得非常简单~~~
2.1加载网络废话不多说,直接po出代码:
将tensorflow import 为tf
用tf.Session(graph = tf.Graph())作为sess:
tf.saved_model.loader.load(sess,[tf.saved_model.tag_constants.TRAINING],PATH)#PATH还是路径
...
...
emmmm …只要通过这么几句代码,就可以完成模型的调用了,但是呢,想要真正地使用这个网络,关键点在于加载网络中的关键节点,如:输入,输出
2.2加载张量(节点,Tensor,变量)神经元,可以是加载张量,也可以是加载节点,变量,代码也非常简单,在上节的基础上,加载张量的代码就只有一句话:
#以加载输入变量为例吧
input_x = sess.graph.get_tensor_by_name('input / input_x:0')#传入的参数会在后文讲到的
3.使用训练好的网络
当把输入的变量以及输出的变量都加载后,就可以使用网络了,假设输出变量为输出,输入变量为X和Y,那么使用这个网络的代码就可以表示如下:
sess.run(output,{
x:_x,y:_y})#大括号的内容就是feed_dict了,使用方法类似于占位符
简要的思路
就结合我的代码来具体分析一下整个过程吧,先说一下简要的思路:
- 训练网络,并构建TensorBoard中的图形图,用于确定网络的关键节点,保存整个网络;
- 通过TensorBoard查找整个的数据流向,确定输入节点,输出节点;
以下是CNN网络的训练和保存部分:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_dat