TensorFlow由ckpt转Pb文件的方法

本文介绍如何将训练好的模型从ckpt格式转换为适合工业部署的pb格式。通过使用TensorFlow API,创建输入占位符并指定输出节点名称,再通过会话加载模型权重,并利用convert_variables_to_constants函数将变量转化为常量,最终保存为pb文件。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在有源代码和ckpt的情况下,想进一步获得用于部署到工业的pb文件。

 

先使用placeholder留一个输入接口,然后搭建模型,得到输出接口的node名称。

 

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.framework import graph_io

from core.my_yolo3 import YOLOV3



pb_file = "./yolo_for_boat.pb"
ckpt_file = "./checkpoint/yolov3_test_loss=1.2524.ckpt-299"
output_node_names = ["input/input_data","post_processing/result"]

with tf.name_scope('input'):
    input_data = tf.placeholder(dtype=tf.float32, name='input_data')
model = YOLOV3(input_data, trainable=False,score_threshold=0.35,org_hw=(360,640))   

开启会话

sess  = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
saver = tf.train.Saver()
saver.restore(sess, ckpt_file)

就到了最关键的部分了,我们需要把图中的variables都变成constant,需要使用

 

def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):

 

frozen_graph = convert_variables_to_constants(sess,sess.graph_def,output_node_names)

graph_io.write_graph(frozen_graph,'./',pb_file,as_text=False)

然后用graph_io.write_graph生成pb就行了,其中as_text必须是False,如果是True,是生成pbtxt的方式。

解释一下convert_variables_to_constants参数的意思。

  • sess: 当前会话
  • input_graph_def: 等于sess.graph_def,不变的,也等价于sess.graph.as_graph_def()
  • output_node_names : 输出结点的名称,但实际上这个东西好像没啥用。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值