持久化代码实现
TensorFlow提供了一个非常简单的API来保存和还原一个神经网络模型。这个API是tf.train.Saver类。
以下代码给出保存TensorFlow计算图的方法。
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0,shape=[1],name="v1"))
v2 = tf.Variable(tf.constant(2.0,shape=[1],name="v2"))
result = v1 + v2
init_op = tf.global_variables_initializer()
# 声明tf.train.Saver()类用于保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
saver.save(sess,"model/model.ckpt")
以上代码实现了持久化一个简单的TensorFlow模型的功能。
上面这段代码会生成第一个文件为model.ckpt.meta,他保存了TensorFlow计算结构。
第二个文件为model.ckpt,这个文件保存了TensorFlow程序每一个变量的取值。
还有一个checkpoint文件,保存了一个目录下所有的模型文件列表。
以下代码中给出加载已保存TensorFlow模型的方法
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0,shape=[1],name="v1"))
v2 = tf.Variable(tf.constant(0.0,shape=[1],name="v2"))
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
#加载已保存的模型,并通过已保存的模型中的变量的值来计算加法
saver.restore(sess,"model/model.ckpt")
print(sess.run(result)) # 输出3.0(值用的模型里的值)
待更新。。。