import tensorflow as tf
from tensorflow.python.tools.freeze_graph import freeze_graph
def freeze(ckpt_dir, meta_file_path, output_node_names):
# 导入结构、加载权重
saver = tf.train.import_meta_graph(meta_file_path)
with tf.Session() as sess:
saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
tf.train.write_graph(sess.graph_def, 'temp/', 'temp.pb') # temp/、temp.pb
freeze_graph(input_graph='temp/temp.pb', # temp/temp.pb
input_checkpoint=tf.train.latest_checkpoint(ckpt_dir),
output_graph='frozen_graph.pb',
output_node_names=output_node_names,
# 以下为固定写法
clear_devices=True,
input_binary=False,
input_saver='',
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
initializer_nodes='')
if __name__ == '__main__':
freeze('./ckpt', './ckpt/model.meta', 'input,classifier/Softmax,regression/BiasAdd')
TensorFlow 将 checkpoint 冻结为 frozen_graph
最新推荐文章于 2024-05-19 15:12:42 发布
本文详细介绍如何使用TensorFlow将训练好的模型进行冻结,以便于在没有TensorFlow运行环境的设备上部署。通过具体代码示例,展示从导入结构、加载权重到最终生成冻结模型文件的全过程。
1万+

被折叠的 条评论
为什么被折叠?



