tf.train.NewCheckpointReader('path'):path是保存的路径,这个函数可以得到保存的所有变量
例如:
先保存一个模型,参数为v,v1.import tensorflow as tf;
import numpy as np;
import matplotlib.pyplot as plt;
v = tf.Variable(0, dtype=tf.float32, name='v')
v1 = tf.Variable(0, dtype=tf.float32, name='v1')
result = v + v1
x = tf.placeholder(tf.float32, shape=[1], name='x')
test = result + x
init = tf.initialize_all_variables()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
saver.save(sess, "/home/penglu/Desktop/lp/model.ckpt") 利用tf.train.NewCheckpointReader导出所有变量import tensorflow as tf;
import numpy as np;
import matplotlib.pyplot as plt;
reader = tf.train.NewCheckpointReader("/home/penglu/Desktop/lp/model.ckpt")
variables = reader.get_variable_to_shape_map()
for ele in variables:
print ele输出:
v1
v
本文介绍了如何使用TensorFlow保存模型及其变量,并演示了通过tf.train.NewCheckpointReader从已保存的模型中读取变量的方法。
350





