tensorflow 读取模型并进行预测

本文介绍了如何在TensorFlow中读取和使用两种模型格式:checkpoint(ckpt)和frozen_graph。首先,详细阐述了模型的保存过程,包括checkpoint模型的结构和frozen_graph模型的固化。接着,分别展示了读取ckpt模型和frozen_graph模型进行预测的方法,特别提醒在使用frozen_graph模型时需注意避免Session graph为空的错误。最后,提供了相关博客资源作为参考。

tensorflow 读取两种格式的模型并进行预测

1. 模型保存

1.1 checkpoint 模型

如图所示,
.meta – 保存图结构,即神经网络的网络结构
.data – 保存数据文件,即网络的权值,偏置,操作等等
.index – 是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。
checkpoint – 文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model.

保存模型:

saver = tf.train.Saver()
saver.save(sess, model_path)

其中model_path是模型保存路径。

1.2 frozen_graph模型

在工程中,我们往往需要将模型和权重固化,便于发布和预测。
使用tensorFlow官方提供的freeze_graph.py工具来保存相应模型。(代码中把freeze_graph.py文件放在commom.utils.tf路径下导入)

freeze_graph.py先加载模型文件,从checkpoint文件读取权重数据初始化到模型里的权重变量,再将权重变量转换成权重常量,然后再通过指定的输出节点将没用于输出推理的Op节点从图中剥离掉,再重新保存到指定的文件里(用write_graphdef或Saver)。

from tensorflow.core.protobuf import saver_pb2
from common.utils.tf import freeze_graph
# save model graph
tf.train.write_graph(
    sess.graph.as_graph_def(),
    os.path.join(model_path),
    GRAPH_PB_NAME,
    as_text=False)
# generate frozen graph
freeze_graph.freeze_graph(
    input_graph=os.path.join(model_path, GRAPH_PB_NAME),
    input_saver=False,
    input_binary=True,
    input_checkpoint=os.path.join(model_pat
TensorFlow中,要读取模型进行预测,通常需要遵循以下步骤: 1. **选择模型格式**: - **Checkpoint (ckpt)**: 这是最常见的模型保存格式,包含了变量及其值。如果使用的是checkpoint模型,你需要先加载变量: ```python with tf.Graph().as_default(): sess = tf.Session() saver = tf.train.import_meta_graph('model.meta') # 使用.meta文件导入元图 saver.restore(sess, 'model.ckpt') # 加载具体的ckpt文件 # 获取输入和输出节点,如"input_placeholder:0"和"output_node:0" input_node = sess.graph.get_tensor_by_name('input_placeholder:0') output_node = sess.graph.get_tensor_by_name('output_node:0') # 进行预测 prediction = sess.run(output_node, {input_node: input_data}) ``` - **Frozen Graph (.pb)**: 是经过优化和冻结操作(`tf.train freezing`)后的二进制格式,可以节省存储空间提高加载速度。使用frozen_graph模型,直接加载预训练的pb文件即可: ```python with tf.gfile.GFile('model.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) _ = tf.import_graph_def(graph_def, name='') # 获取输入和输出节点 input_node = sess.graph.get_tensor_by_name('input_node:0') output_node = sess.graph.get_tensor_by_name('output_node:0') prediction = sess.run(output_node, {input_node: input_data}) ``` 2. **使用上下文管理器Graph().as_default()**: 这有助于确保在模型预测时不会遇到空图错误。 3. **设置Session**: 创建一个会话来运行计算图。 4. **获取输入和输出节点**: 根据模型结构找到对应的占位符和输出节点。 5. **执行预测**: 提供输入数据,通过会话运行输出节点以得到预测结果。 请注意,具体实现可能因模型结构的不同而有所差异。如果你遇到了"The Session graph is empty"的错误,确认已经正确加载了模型且图中包含了必要的操作。此外,博客[^1]提供了更详细的教程和示例,可供参考。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值