0.前言
首先我们要理解TensorFlow的一个规则,首先构建计算图(graph),然后初始化graph中的data,这两步是分开的。
1.如何恢复模型
有两种方式(这两种方式有比较大的不同):
1.1 重新使用代码构建图
举个例子(完整代码):
def build_graph():
w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32)
w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32)
w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3')
w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4')
add = tf.add(w1,w2,name='add')
add1 = tf.add(add,w3,name='add1')
return w3,add1
with tf.Session() as sess:
ckpt_state = tf.train.get_checkpoint_state('./temp/')
if ckpt_state:
w3,add1=build_graph()
saver = tf.train.Saver()
saver.restore(sess, ckpt_state.model_checkpoint_path)
else:
w3,add1=build_graph()
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
saver = tf.train.Saver()
a = sess.run(add1,feed_dict={