在进行深度学习任务时,加载和操作已经训练好的模型是常见的工作流程。TensorFlow作为一个流行的深度学习框架,允许我们读取预训练的模型并将其用于推理。本文将介绍读取TensorFlow模型图的代码,并对其进行技术说明。
代码说明
- load_graph: 该函数负责加载给定路径的模型文件(
.pb
文件)并返回一个TensorFlow图(tf.Graph
)。 - list_tensor_names: 从图中提取所有张量的名称。
- save_tensor_names: 将提取到的张量名称保存到指定的文本文件中。
- main: 主要执行流程,负责协调以上函数的调用。
完整代码
import tensorflow as tf
import os
import argparse
# 兼容TensorFlow 2.x的代码
def load_graph(model_dir, model_name):
"""
加载TensorFlow模型并返回当前图。
:param model_dir: 模型所在目录
:param model_name: 模型名称(.pb文件)
:return: 当前的TensorFlow图
"""
model_path = os.path.join(model_dir, model_name)
# 使用 TensorFlow 2.x 兼容模式
tf.compat.v1.disable_eager_execution() # 关闭急切执行模式以兼容图执行
with tf.compat.v1.gfile.GFile(model_path, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='') # 导入模型到图中
return graph
def list_tensor_names(graph):
"""
获取图中的所有张量名称。
:param graph: TensorFlow计算图
:return: 张量名称的列表
"""
tensor_names = [tensor.name for tensor in graph.as_graph_def().node]
return tensor_names
def save_tensor_names(tensor_names, result_file):
"""
将张量名称保存到文件中。
:param tensor_names: 张量名称的列表
:param result_file: 结果保存的文件路径
"""
with open(result_file, 'w+') as f:
for tensor_name in tensor_names:
f.write(tensor_name + '\n')
def main(model_dir, model_name):
"""
主函数,执行模型加载,张量名称提取和保存操作。
:param model_dir: 模型路径
:param model_name: 模型名称
"""
graph = load_graph(model_dir, model_name)
tensor_names = list_tensor_names(graph)
# 结果文件路径
result_file = os.path.join(model_dir, 'result.txt')
save_tensor_names(tensor_names, result_file)
print(f"Tensor names have been saved to {result_file}")
if __name__ == '__main__':
# 使用argparse接收命令行输入
parser = argparse.ArgumentParser(description="Extract tensor names from a TensorFlow model")
parser.add_argument('model_dir', type=str, help='Directory of the model')
parser.add_argument('model_name', type=str, help='Name of the model file (e.g., model.pb)')
args = parser.parse_args()
main(args.model_dir, args.model_name)
使用说明
运行以下脚本时,用户可以指定模型的路径和名称,程序会自动从模型中提取所有张量名称并保存到result.txt
文件中。输出结果为模型图中的所有张量名称列表。
python extract_tensor_names.py "D:/TensorFlow/MyTensorFlow/MyTensorFlow/slim/satellite" "inception_v3_frozen_graph.pb"
执行后,所有张量名称将保存在result.txt
文件中。