- 模型保存和还原都是通过TensorFlow的一个API实现的,这个API就是tf.train.saver类
- 保存模型
- eg:
import tensorflow as tf
#声明两个变量并计算他们的和
v1 = tf.Variable(tf.constant(1.0,shape=[1],name="v1"))
v2 = tf.Variable(tf.constant(2.0,shape=[1],name="v2"))
result = v1+v2
init_op = tf.initialize_all_variables()
# 声明tf.train.saver类用于保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
# 将模型保存到相关路径下
saver.save(sess, "./model/ckpt")
结果为:
- 还原模型
import tensorflow as tf
#使用和保存时一样的代码声明两个变量并计算他们的和
v1 = tf.Variable(tf.constant(1.0,shape=[1],name="v1"))
v2 = tf.Variable(tf.constant(2.0,shape=[1],name="v2"))
result = v1+v2
init_op = tf.initialize_all_variables()
# 声明tf.train.saver类用于保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess,"./model/ckpt" )
print(sess.run(result))
结果为:[3.]
参考:《TensorFlow:实战Google深度学习框架》