TensorFlow模型图的读取与导出:获得一个pb文件的所有节点名称

在进行深度学习任务时,加载和操作已经训练好的模型是常见的工作流程。TensorFlow作为一个流行的深度学习框架,允许我们读取预训练的模型并将其用于推理。本文将介绍读取TensorFlow模型图的代码,并对其进行技术说明。

代码说明

  1. load_graph: 该函数负责加载给定路径的模型文件(.pb文件)并返回一个TensorFlow图(tf.Graph)。
  2. list_tensor_names: 从图中提取所有张量的名称。
  3. save_tensor_names: 将提取到的张量名称保存到指定的文本文件中。
  4. main: 主要执行流程,负责协调以上函数的调用。

完整代码

import tensorflow as tf
import os
import argparse

# 兼容TensorFlow 2.x的代码
def load_graph(model_dir, model_name):
    """
    加载TensorFlow模型并返回当前图。
    :param model_dir: 模型所在目录
    :param model_name: 模型名称(.pb文件)
    :return: 当前的TensorFlow图
    """
    model_path = os.path.join(model_dir, model_name)
    
    # 使用 TensorFlow 2.x 兼容模式
    tf.compat.v1.disable_eager_execution()  # 关闭急切执行模式以兼容图执行
    with tf.compat.v1.gfile.GFile(model_path, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name='')  # 导入模型到图中
    return graph

def list_tensor_names(graph):
    """
    获取图中的所有张量名称。
    :param graph: TensorFlow计算图
    :return: 张量名称的列表
    """
    tensor_names = [tensor.name for tensor in graph.as_graph_def().node]
    return tensor_names

def save_tensor_names(tensor_names, result_file):
    """
    将张量名称保存到文件中。
    :param tensor_names: 张量名称的列表
    :param result_file: 结果保存的文件路径
    """
    with open(result_file, 'w+') as f:
        for tensor_name in tensor_names:
            f.write(tensor_name + '\n')

def main(model_dir, model_name):
    """
    主函数,执行模型加载,张量名称提取和保存操作。
    :param model_dir: 模型路径
    :param model_name: 模型名称
    """
    graph = load_graph(model_dir, model_name)
    tensor_names = list_tensor_names(graph)
    
    # 结果文件路径
    result_file = os.path.join(model_dir, 'result.txt')
    save_tensor_names(tensor_names, result_file)
    print(f"Tensor names have been saved to {result_file}")

if __name__ == '__main__':
    # 使用argparse接收命令行输入
    parser = argparse.ArgumentParser(description="Extract tensor names from a TensorFlow model")
    parser.add_argument('model_dir', type=str, help='Directory of the model')
    parser.add_argument('model_name', type=str, help='Name of the model file (e.g., model.pb)')
    
    args = parser.parse_args()
    main(args.model_dir, args.model_name)

使用说明

运行以下脚本时,用户可以指定模型的路径和名称,程序会自动从模型中提取所有张量名称并保存到result.txt文件中。输出结果为模型图中的所有张量名称列表。

python extract_tensor_names.py "D:/TensorFlow/MyTensorFlow/MyTensorFlow/slim/satellite" "inception_v3_frozen_graph.pb"

执行后,所有张量名称将保存在result.txt文件中。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值