import graph from file,导入已有的图模型,用tensorboard查看图模型

本文介绍如何使用TensorFlow通过GraphDef加载预训练模型,并展示了如何利用这些模型进行预测。此外,还提供了使用TensorBoard可视化模型结构的方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

graph_def = tf.GraphDef()     # 新建GraphDef文件,用于临时载入模型中的图
graph_def.ParseFromString(f.read())      # GraphDef加载模型中的图
tf.import_graph_def(graph_def, name='')       # 在当前默认图中加载GraphDef中的图

 函数tf.import_graph_def的定义如下

def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
    pass
# import graph from file
with tf.gfile.GFile("retrained_graph.pb", 'rb') as f:
    # 新建GraphDef文件,用于临时载入模型中的图
    graph_def = tf.GraphDef()
    # GraphDef加载模型中的图
    graph_def.ParseFromString(f.read())
    # 在空白图中加载GraphDef中的图
    tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
    # Feed the image_data as input to the graph and get first prediction
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    predictions = sess.run(softmax_tensor, {'DecodeJpeg/contents:0': image_data})

另外一个

def create_inception_graph():
  """"Creates a graph from saved GraphDef file and returns a Graph object.

  Returns:
    Graph holding the trained Inception network, and various tensors we'll be
    manipulating.
  """
  with tf.Session() as sess:
    model_filename = os.path.join(
        FLAGS.model_dir, 'classify_image_graph_def.pb')
    with gfile.GFile(model_filename, 'rb') as f:
        # 新建GraphDef文件,用于临时载入模型中的图
      graph_def = tf.GraphDef()
        # GraphDef加载模型中的图
      graph_def.ParseFromString(f.read())
        # 在空白图中加载GraphDef中的图
      bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
          tf.import_graph_def(graph_def, name='', return_elements=[
              BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
              RESIZED_INPUT_TENSOR_NAME]))
  return sess.graph, bottleneck_tensor, jpeg_data_tensor, resized_input_tensor

用tensorboard查看图模型

from tensorflow as tf 
model = 'model.pb' 
graph = tf.get_default_graph() 
graph_def = graph.as_graph_def() 
graph_def.ParseFromString(tf.gfile.FastGFile(model, 'rb').read()) tf.import_graph_def(graph_def, name='graph') 
summaryWriter = tf.summary.FileWriter('log/', graph)

# 使用tensorboard命令查看
tensorboard --logdir DIR --host IP --port PORT
# 一般情况下,不设置host和port,就会在localhost:6006启动。DIR是路径(不加引号)。

 

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值