代码:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @ Date : 2022/12/21 13:19
# @ Author : paperClub
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' # 去掉警告
import tensorflow
if int(tensorflow.__version__[0]) == 2:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
else:
import tensorflow as tf
def load_model(pb_model_file):
graph = tf.Graph()
with graph.as_default():
graph_def = tf.GraphDef()
graph_def.ParseFromString(open(pb_model_file, 'rb').read())
tensors = tf.import_graph_def(graph_def, name="")
session = tf.Session(graph=graph)
with session.as_default():
with graph.as_default():
init = tf.global_variables_initializer()
session.run(init)
session.graph.get_operations()
return session
session = None
if session is None:
pb_model_file = "./tf2_model.pb"
session = load_model(pb_model_file)
layer_input = 'input_1

本文介绍了一个使用Python和TensorFlow加载PB格式模型的方法,并演示了如何利用该模型进行预测。文中提供了完整的代码示例,包括设置环境变量、选择TensorFlow版本以及定义加载模型的函数等。
最低0.47元/天 解锁文章
2402

被折叠的 条评论
为什么被折叠?



