在有源代码和ckpt的情况下,想进一步获得用于部署到工业的pb文件。
先使用placeholder留一个输入接口,然后搭建模型,得到输出接口的node名称。
import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.framework import graph_io
from core.my_yolo3 import YOLOV3
pb_file = "./yolo_for_boat.pb"
ckpt_file = "./checkpoint/yolov3_test_loss=1.2524.ckpt-299"
output_node_names = ["input/input_data","post_processing/result"]
with tf.name_scope('input'):
input_data = tf.placeholder(dtype=tf.float32, name='input_data')
model = YOLOV3(input_data, trainable=False,score_threshold=0.35,org_hw=(360,640))
开启会话
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
saver = tf.train.Saver()
saver.restore(sess, ckpt_file)
就到了最关键的部分了,我们需要把图中的variables都变成constant,需要使用
def convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None):
frozen_graph = convert_variables_to_constants(sess,sess.graph_def,output_node_names)
graph_io.write_graph(frozen_graph,'./',pb_file,as_text=False)
然后用graph_io.write_graph生成pb就行了,其中as_text必须是False,如果是True,是生成pbtxt的方式。
解释一下convert_variables_to_constants参数的意思。
- sess: 当前会话
- input_graph_def: 等于sess.graph_def,不变的,也等价于sess.graph.as_graph_def()
- output_node_names : 输出结点的名称,但实际上这个东西好像没啥用。