tensorflow之pb文件保存与载入

本文详细介绍了TensorFlow中PB文件的用途与生成过程,包括如何从.ckpt文件转换为PB文件,以及如何调用PB文件进行预测。文章还提供了代码示例,帮助读者理解输入输出节点的指定方式。

pb是protocol(协议) buffer(缓冲)的缩写。TensorFlow训练模型后存成的pb文件,是一种表示模型(神经网络)结构的二进制文件,将图中的变量保存成为常量,便于调用,一般无法将pb文件映射成源代码。pb文件的能够保存tensorflow计算图中的操作节点以及对应的各张量,方便我们日后直接调用之前已经训练好的计算图。

注:pb文件可以在训练时直接保存,也可以用.ckpt文件转化为.pb文件。

准备所需的文件:

1、训练完成的.index,.data,.mate文件。

2、编写.ckpt文件转.pb文件的代码。

output_node_names = ["Input/X_placeholder", "Inference/output"]#指定输入输出节点名
def freeze_pb(pb_file,ckpt_path):

    with tf.name_scope('Input'):
        input_data = tf.placeholder(dtype=tf.float32,shape=[None, 784], name='X_placeholder')

    with tf.name_scope('Inference'):
        # batch:20 输入:784,通道:1,输出:10
        W = tf.Variable(initial_value=tf.random_normal(shape=[784,10], stddev=0.01), name='Weights')
        b = tf.Variable(initial_value=tf.zeros(shape=[10]), name='bias')
        print(W)
        logits = tf.matmul(input_data, W) + b
        pred = tf.nn.softmax(logits=logits,name='output')


    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    saver = tf.train.Saver()
    saver.restore(sess, ckpt_path)

    converted_graph_def = tf.graph_util.convert_variables_to_constants(sess,
                                                                       input_graph_def=sess.graph.as_graph_def(),
                                                                       output_node_names=output_node_names)

    with tf.gfile.GFile(pb_file, "wb") as f:
        f.write(converted_graph_def.SerializeToString())

注:

1、在将ckpt文件转为pb文件的时候,一定要将输入节点名称和输出节点名称与设计的网络中的输入节点名称和输出节点名称对应起来,否则会报错。例如这里将

Input/X_placeholder改为Input/X_placeholder_e时会有如下的报错。

AssertionError: Input/X_placeholdere_e is not in graph.

output_node_names = ["Input/X_placeholder", "Inference/output"]指定输入输出。

2、tf.graph_util.convert_variables_to_constants中的output_node_names变量来指定保存的节点名称而不是张量的名称,“Input/X_placeholder:0”是张量的名称而"Input/X_placeholder"表示的是节点的名称。在固化pb文件的时候用节点名称,在调用pb文件是使用张量名称。

执行上述代码后会在指定的目录下生成pb文件。

3、编写调用pb文件代码。

return_enement = ["Input/X_placeholder:0", "Inference/output:0"]#指定输入输出张量名称
def Load_PbFile(pb_file,image):
    imagedata=np.array(image).reshape([1,784])
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True


    with tf.gfile.FastGFile(pb_file, 'rb') as f:
        frozen_graph_def = tf.GraphDef()
        frozen_graph_def.ParseFromString(f.read())
        elements = tf.import_graph_def(frozen_graph_def,return_elements=return_enement)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        prob=sess.run([elements[1]],feed_dict={elements[0]:imagedata})

        tolist=list(prob[0][0])
        classes = tolist.index(max(tolist))
        print(classes)

这里return_enement = ["Input/X_placeholder:0", "Inference/output:0"]指定的是张量名称,区别于output_node_names = ["Input/X_placeholder", "Inference/output"],否则在运行时会出现TypeError: Cannot interpret feed_dict key as Tensor: Can not convert a Operation into a Tensor.的错误。

<think>好的,用户问的是TensorFlowpb文件通常是什么。我需要先回忆一下pb文件的相关知识。首先,pb文件应该是指Protocol Buffers格式的文件,但具体在TensorFlow里有什么用途呢?可能涉及到保存模型的结构和参数,对吧? 用户可能是个刚开始接触TensorFlow的开发者,或者是在部署模型时遇到了pb文件,需要了解其基本概念。他们可能想知道pb文件的作用、如何生成以及如何使用。深层需求可能是想了解如何保存和加载模型,或者进行模型部署。 接下来,我需要分步骤解释。首先,pb文件的全称是Protocol Buffers,是Google开发的一种数据序列化格式。在TensorFlow中,pb文件通常保存了训练好的模型,包括计算图结构和变量数据。这里要强调它和checkpoint文件的区别,checkpoint只保存参数,而pb文件包含图和参数,适合部署。 然后,生成pb文件的方法。用户可能想知道如何从代码中导出pb文件,所以需要提到冻结图(freeze_graph)的过程,或者使用tf.saved_model.simple_save这样的API。比如将变量转为常量,并保存为.pb文件。 使用场景方面,pb文件用于跨平台部署,比如移动端或嵌入式设备,或者用在其他语言中调用。这时候需要说明pb文件的优点,比如体积小、加载快,包含完整模型结构。 结构部分,可以简要说明pb文件内部包含计算图、操作节点、张量信息,可能还有变量数据。这部分不需要太深入,但要让用户理解其内容。 最后,注意用户可能遇到的常见问题,比如生成pb文件时的错误,或者加载时的版本兼容性问题,但根据用户的问题,可能不需要展开,只要给出基本信息即可。 检查是否需要提到SavedModel格式,因为有时候pb文件也可能指SavedModel中的内容,但通常单独的pb文件是冻结后的模型。需要区分清楚,避免混淆。 确保使用中文,结构清晰,分点说明,避免使用专业术语过多,必要时举例。比如保存为.pb文件的代码示例,或者使用freeze_graph工具的方法。 最后总结pb文件的作用和重要性,帮助用户形成完整的认识。</think>TensorFlow的`.pb`文件(Protocol Buffers格式)是用于保存模型结构和参数的序列化文件。以下是逐步说明: ### 1. **核心定义** `.pb`文件TensorFlow模型的**冻结格式**(Frozen Graph),通常包含: - **计算图结构**(Graph Definition):模型的所有层、操作和连接关系 - **固化后的参数**:训练好的权重和偏置等参数被转换为常量存储 ### 2. **生成方式** 通过以下步骤生成: ```python # 示例:将模型保存为.pb格式 import tensorflow as tf # 定义并训练模型后... graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( sess, # 当前会话 sess.graph.as_graph_def(), # 原始计算图 output_node_names=['output_layer'] # 指定输出节点 ) # 写入文件 with tf.io.gfile.GFile('model.pb', 'wb') as f: f.write(graph_def.SerializeToString()) ``` ### 3. **关键特性** - **独立部署**:无需原始训练代码即可加载 - **跨平台**:支持移动端(TensorFlow Lite)、服务端(TensorFlow Serving)等 - **高效推理**:计算图已优化,适合生产环境 ### 4. **Checkpoint的区别** | 文件类型 | 内容 | 用途 | |----------------|-----------------------|--------------| | `.ckpt` | 变量参数(未固化) | 训练恢复 | | `.pb` | 固化图+参数 | 部署推理 | ### 5. **典型应用场景** - 使用`tf.import_graph_def`加载模型: ```python with tf.io.gfile.GFile('model.pb', 'rb') as f: graph_def = tf.compat.v1.GraphDef() graph_def.ParseFromString(f.read()) with tf.Graph().as_default() as graph: tf.import_graph_def(graph_def, name='') ``` ### 6. **文件验证** 可通过工具检查内容: ```bash saved_model_cli show --dir /path/to/model --tag_set serve --signature_def serving_default ``` ### 总结 `.pb`文件TensorFlow模型的标准部署格式,通过固化计算图和参数实现轻量化、高效率的推理。相较于其他保存方式(如HDF5),更适合生产环境中的跨平台使用。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值