import cv2
import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint ,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
with tf.name_scope('input'):
input_data = tf.placeholder(dtype=tf.float32, shape=[1, 256, 512, 3], name='input_data') #定义输入节点
net = lanenet.LaneNet(phase='test', net_flag='vgg')
ret = net.inference(input_tensor=input_data, name='lanenet_model') # 网络模型结构
output = tf.identity(ret, name='output_label') # 定义输出节点
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # 获得默认的图
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
with tf.Session() as sess:
saver.restore(sess, input_checkp