Tensorflow 提供了一个非常简单的API 来保存和还原一个神经网络模型这个API就是tf.train.Saver类,一下代码给出了保存Tensorflow计算图的方法
import tensorflow as tf#声明两个变量并计算他们的和
v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = "v1")
v2 = tf.Variable(tf.constant(1.0, shape = [1]), name = "v2")
result = v1+v2
init_op = tf.global_variable_initializer()
#声明tf.train.Saver 类用于保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(inti_op)
#将模型保存到 /path/to/model/model.ckpt 文件
saver.save(sess,"/path/to/model/model.ckpt")
以上代码实现了持久化一个简单的TensorFlow模型的功能,在这段代码中saver.save函数将TensorFlow模型保存到 /path/to/model/model.ckpt 文件。TensorFlow模型一般会存在后缀为.ckpt的文件中。虽然以上程序只指定了一个文件路径,但是在这个文件目录下回出现三个文件。这是因为Tensor Flow会将计算图的结构和图上的参数分开来保存。
上面这段代码生成的第一个文件为model.ckpt.meta, 它保存了TensorFlow模型的计算图结构 ,可以简单的理解为神经网络的网络结构,第二个文件为model.ckpt,保存TensorFlow程序中每个变量的取值,最后一个文件为checkpoint文件,保存了一个目录下所有的模型文件列表
以下代码中给出了加载这个已经保存的TensorFlow模型的方法
import tensorflow as tf#使用和保存模型代码中一样的方式来声明变量v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = "v1")v2 = tf.Variable(tf.constant(1.0, shape = [1]), name = "v2")result = v1+v2 saver = tf.train.Saver()with tf.Session() as sess: #加载已经保存的模型,并通过已经保存的模型中变量的值来计算加法 sess.restore(sess,"/path/to/model/model.ckpt") print sess.run(result )
两段代码唯一不同的是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。如果不希望重复定义图上的运算,也可以直接加载已经持久化的图,以下代码给出了一个样例
import tensorflow as tf
#直接加载持久化的图
saver = tf.train.import_meta_graph("/path/to/model/model.ckpt/model.ckpt.meta")
with tf.Session() as sess:
sess.restore(sess,"/path/to/model/model.ckpt")
#通过张量的名称来获取张量
print sess.run( tf.get_default_graph().get_tensor_by_name("add:0"))
#输出[3.]
以上程序中默认加载了tensorFlow计算图中上定义的全部变量,但有时可能只需要保存和加载部分变量。下节介绍如何保存或者加载部分变量