背景:tensorflow需要加载checkpoint文件模型或者pb格式模型。
问题:怎么加载保存的checkpoint并预测,加载pb格式文件加载时多次预测会内存泄漏。
解决方法
1、checkpoint格式加载与预测
# 加载
self.graph = tf.Graph() # 为每个类(实例)单独创建一个graph
with self.graph.as_default():
self.check_point_file = tf.train.latest_checkpoint(self.model_path)
self.saver = tf.train.import_meta_graph("{}.meta".format(self.check_point_file)) # 创建恢复器
# 注意!恢复器必须要在新创建的图里面生成,否则会出错。
self.sess = tf.Session()
with self.sess.as_default():
self.saver.restore(self.sess, self.check_point_file)
# 预测
def predict():
with self.graph.as_default():
with self.sess.as_default():
input_x = self.graph.get_tensor_by_name("input_x:0")
input_y = self.graph.get_tensor_by_name("input_y:0")
q_y_raw = self.graph.get_tensor_by_name("representation/q_y_raw:0")
qs_y_raw = self.graph.get_tensor_by_name("representation/qs_y_raw:0")
qs_y_raw_out = self.sess.run(qs_y_raw, feed_dict={input_y: np.array(input_y_value, dtype=np.int32)})
2、pb格式 加载与预测
# 加载
self.graph = tf.Graph() # 为每个类(实例)单独创建一个graph
with self.graph.as_default():
self.sess = tf.Session(graph=self.graph)
with self.sess.as_default():
output_graph_def = tf.GraphDef()
pb_path = wenlp_configs["sentence_matcher"]["pb_model_path"]
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name="")
# 预测
with tf.Session(graph=self.graph) as self.sess:
input_x = self.graph.get_tensor_by_name("input_x:0")
input_y = self.graph.get_tensor_by_name("input_y:0")
q_y_raw = self.graph.get_tensor_by_name("representation/q_y_raw:0")
qs_y_raw = self.graph.get_tensor_by_name("representation/qs_y_raw:0")
qs_y_raw_out = self.sess.run(qs_y_raw, feed_dict={input_y: np.array(input_y_value, dtype=np.int32)})