将tensorflow的ckpt模型转化为pb模型,可以大大提高网络预测速度,是进行部署的第一步。怎么做参考:这里。我看网上资料较少,我写一下怎么读取pb模型进行测试,通常落地会采用c++这种更底层的语言。
具体怎么写需要根据网络的测试代码来写,每个网络输入输出不一样,我在下面贴一个写好的只作为参考。
总体步骤:
1.读入pb文件
def freeze_graph_test2(pb_path, test_path):
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name="")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
这部分代码直接复制,tf的常规操作,把pb_path替换成自己的pb文件地址就行
2.定义输入
keep_probability = sess.graph.get_tensor_by_name(name="keep