# 1. Save the model as a ckpt (checkpoint) file
import tensorflow as tf
import numpy as np
if __name__ == '__main__':
    # Build a small graph computing (input + v1 + v2) x w, initialise its
    # variables and write them to a checkpoint under "log/".
    #
    # Ops are created in a fixed order and with the overloaded `+` operator
    # on purpose: TensorFlow derives default node names from the op type and
    # creation order, and the rest of this file looks up "input" and "MatMul"
    # by name.
    input_data = tf.placeholder(dtype=tf.float32, shape=[2, 3], name='input')
    print("input_data.node_name = {0}".format(input_data.name))

    # Two randomly initialised offsets with the same shape as the input.
    v1_init = tf.random_normal(shape=[2, 3], mean=1.0, stddev=0.5)
    v1 = tf.Variable(initial_value=v1_init, name='v1')
    v2_init = tf.random_normal(shape=[2, 3], mean=1.0, stddev=0.5)
    v2 = tf.Variable(initial_value=v2_init, name='v2')

    shifted = input_data + v1
    p3 = shifted + v2
    print("p3.node_name = {0}".format(p3.name))

    # Project the [2, 3] sum down to [2, 2] with a random weight matrix.
    w_init = tf.random_normal(shape=[3, 2], mean=0.3, stddev=0.2)
    weights = tf.Variable(initial_value=w_init, name='w')
    wp = tf.matmul(p3, weights)
    print("wp.node_name = {0}".format(wp.name))

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, 'log/model.ckpt')

        # Read back the checkpoint bookkeeping file and report what was saved.
        state = tf.train.get_checkpoint_state('log')
        if state and state.model_checkpoint_path:
            print(state.model_checkpoint_path)
# 2. Convert the ckpt file into a pb (frozen graph) file
import tensorflow as tf
from tensorflow.python.framework import graph_util
# Freeze the checkpoint stored under log/ into one self-contained GraphDef
# file: every variable reachable from the output node is replaced by a
# constant holding its restored value.
input_checkpoint = 'log/model.ckpt'
# Comma-separated names of the nodes the frozen graph must still produce.
output_node_name = 'MatMul'
# Destination of the frozen graph (previously left empty and unused while
# the path was hard-coded at the write call).
output_graph = 'log/FirstPB.pb'

# Import into a dedicated graph so node names such as "MatMul" cannot be
# affected by anything that already lives in the process-wide default graph.
with tf.Graph().as_default() as graph:
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    inout_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        # Variables must hold their checkpointed values before conversion.
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=inout_graph_def,
            output_node_names=output_node_name.split(','))

# The frozen GraphDef is a plain in-memory protobuf; no session is needed
# to serialise it.
with tf.gfile.GFile(output_graph, 'wb') as f:
    f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
# 3. Run prediction using the pb file
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import graph_util
# Load the frozen graph from disk and run one forward pass on random input.
pb_path = 'log/FirstPB.pb'
with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    input_data = np.random.rand(2, 3)

    with open(pb_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    # name="" imports the nodes without a name prefix, so "input" and
    # "MatMul" can be looked up under the names they were saved with.
    tf.import_graph_def(output_graph_def, name="")

    with tf.Session() as sess:
        # No variable initialisation here: the .pb is produced by
        # convert_variables_to_constants, so the graph holds constants
        # only and there is nothing to initialise.
        input_tensor = sess.graph.get_tensor_by_name("input:0")
        output_tensor = sess.graph.get_tensor_by_name("MatMul:0")
        output = sess.run(output_tensor,
                          feed_dict={input_tensor: input_data})
        print("output = {0}".format(output))