TensorFlow:获取 ckpt 模型和 pb 模型的节点名称,以及将 ckpt 模型转换为 pb 模型

本文介绍如何使用TensorFlow将ckpt文件转换为pb格式,包括读取ckpt文件、创建和导出pb模型的过程。提供了针对不同模型(如pnet、rnet)的具体实现代码,并涉及自定义算子的转换。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

获取 ckpt 模型中的变量名称

from tensorflow.python import pywrap_tensorflow 
# Print the name of every variable stored in a TF1 checkpoint.
checkpoint_path = 'model.ckpt-8000'  # checkpoint prefix, as passed to the reader
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()  # {variable name: shape}
for var_name in var_to_shape_map:
    print("tensor_name: ", var_name)

获取 pb 模型中的节点名称

import tensorflow as tf
import os

model_name = './mobilenet_v2_140_inf_graph.pb'

def create_graph(path=None):
    """Load a serialized GraphDef (.pb) into the current default graph.

    Args:
        path: Path of the .pb file. Defaults to the module-level
            ``model_name`` so existing ``create_graph()`` calls keep working.

    Returns:
        The parsed ``tf.GraphDef`` (the original discarded it).
    """
    if path is None:
        path = model_name
    # tf.gfile.FastGFile is deprecated; GFile reads the same bytes.
    with tf.gfile.GFile(path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' imports nodes under their original names (no "import/" prefix).
    tf.import_graph_def(graph_def, name='')
    return graph_def

create_graph()
# Each NodeDef name is an op name (a tensor name would add ":<index>").
graph_def = tf.get_default_graph().as_graph_def()
tensor_name_list = [node.name for node in graph_def.node]
for node_name in tensor_name_list:
    print(node_name, '\n')

ckpt转pb

def freeze_graph(input_checkpoint, output_graph, output_node_names=None):
    '''
    Freeze a checkpoint (graph from ``<input_checkpoint>.meta`` plus its weights)
    into a single PB file.

    :param input_checkpoint: checkpoint prefix; ``input_checkpoint + '.meta'`` must exist
    :param output_graph: path where the frozen PB model is saved
    :param output_node_names: comma-separated names of the output nodes to keep.
        The original hard-coded the placeholder "xxx", which is not a real node,
        so the names must now be supplied explicitly.
    :return: the frozen GraphDef
    :raises ValueError: if no output node names are given
    '''
    # Local import so this snippet runs on its own (it was used before being imported).
    from tensorflow.python.framework import graph_util

    if not output_node_names:
        raise ValueError("output_node_names must name at least one output node")

    # Start from an empty graph so repeated calls do not accumulate duplicate ops.
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()
    with tf.Session() as sess:
        saver.restore(sess, input_checkpoint)
        output_graph_def = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,  # equivalent to sess.graph_def
            output_node_names=output_node_names.split(","))
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

        for op in graph.get_operations():
            print(op.name, op.values())
    return output_graph_def
import tensorflow as tf
from tensorflow.python.framework import graph_util

def freeze_graph(input_checkpoint, output_node_names, output_graph):
    """Freeze ``input_checkpoint`` into the PB file ``output_graph``.

    Every node name of the restored graph is printed first. Only the
    sub-graph feeding the comma-separated ``output_node_names`` is kept,
    with variables replaced by constants.
    """
    tf.reset_default_graph()
    restorer = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
    with tf.Session() as session:
        restorer.restore(session, input_checkpoint)
        full_graph_def = tf.get_default_graph().as_graph_def()
        for node in full_graph_def.node:
            print(node.name)
        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess=session,
            input_graph_def=full_graph_def,
            output_node_names=output_node_names.split(","),
        )
        with tf.gfile.GFile(output_graph, "wb") as out_file:
            out_file.write(frozen_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(frozen_graph_def.node))

if __name__ == '__main__':
    # One entry per network: (checkpoint prefix, output .pb path, output node names).
    jobs = [
        # ----pnet----
        ("../models/ckpt/pnet_new", "../models/pb/pnet.pb",
         "pnet/cls_prob,pnet/conv4_2/BiasAdd"),
        # ----rnet----
        ("../models/ckpt/rnet_new", "../models/pb/rnet.pb",
         "rnet/cls_prob,rnet/fc3/BiasAdd"),
    ]
    for index, (input_checkpoint, out_pb_path, output_node_names) in enumerate(jobs):
        if index > 0:
            print("-----------------------------------------")
        freeze_graph(input_checkpoint, output_node_names, out_pb_path)

TensorFlow 含自定义算子模型的 ckpt 转 pb

#coding:utf-8

import tensorflow as tf
import os
import numpy as np
import sys
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import graph_util
from google.protobuf import text_format
from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes
sys.path.append("../code/acc/core/")
import model

def freeze_graph_pnet(output_pb=None):
    """Rebuild P-Net (quantized, inference mode), restore its weights and freeze it.

    The rebuilt graph is also dumped as text to
    ``../models/pb/pnet-54_frozen_graph.pbtxt``.

    Args:
        output_pb: Optional path. When given, the frozen, stripped GraphDef is
            serialized there. The original computed it and then discarded it.

    Returns:
        The frozen, stripped ``GraphDef``.
    """
    output_nodes = ['pnet/cls_prob', 'pnet/conv4_2/BiasAdd']
    ckpt_filename = '../models/ckpt_clean/pnet-54'

    with tf.Graph().as_default():
        input_data = tf.placeholder(tf.float32, shape=(1, None, None, 1), name='input')
        emb = model.P_Net(input_data, is_training=False, is_quantize=True)
        print(emb)

        # Restore directly into the variables P_Net just created. The original
        # called tf.train.import_meta_graph here, which imports a second copy of
        # the saved graph on top of the rebuilt one. Its restore ops then most
        # likely targeted that (renamed) copy, leaving P_Net's own variables at
        # their initial values.
        # NOTE(review): assumes the checkpoint contains every variable P_Net
        # creates with is_quantize=True; otherwise pass var_list to Saver.
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, ckpt_filename)
            graph_def = sess.graph_def

            with tf.gfile.GFile('../models/pb/pnet-54_frozen_graph.pbtxt', 'wb') as f:
                f.write(text_format.MessageToString(graph_def))

            constant_graph = graph_util.convert_variables_to_constants(
                sess, graph_def, output_nodes)

    constant_graph = graph_util.remove_training_nodes(constant_graph)
    constant_graph = strip_unused_lib.strip_unused(constant_graph,
                                                   ['input'],
                                                   output_nodes,
                                                   dtypes.float32.as_datatype_enum)
    if output_pb:
        with tf.gfile.GFile(output_pb, 'wb') as f:
            f.write(constant_graph.SerializeToString())
    return constant_graph

读取 ckpt 模型中指定变量的值

from tensorflow.python import pywrap_tensorflow
import numpy as np

def get_value_frome_ckpt(ckpt_filename, pattern="/norm", exclude=("local_step", "biased")):
    """Print, and return, the checkpoint tensors whose names match a filter.

    A variable is selected when its name contains ``pattern`` and none of the
    substrings in ``exclude``. The defaults reproduce the original hard-coded
    filter.

    Args:
        ckpt_filename: Checkpoint path/prefix passed to NewCheckpointReader.
        pattern: Substring a variable name must contain.
        exclude: Substrings that disqualify a variable name.

    Returns:
        dict mapping each selected variable name to its value.
    """
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_filename)
    selected = {}
    for key in reader.get_variable_to_shape_map():
        if pattern not in key or any(word in key for word in exclude):
            continue
        print("tensor_name: ", key)
        value = reader.get_tensor(key)
        print(value)
        selected[key] = value
    return selected

if __name__ == '__main__':
    import sys

    # The original passed an empty placeholder path, which can never be read.
    # Take the checkpoint path/prefix from the command line instead.
    if len(sys.argv) < 2:
        sys.exit("usage: %s CHECKPOINT_PATH" % sys.argv[0])
    ckpt_filename = sys.argv[1]
    get_value_frome_ckpt(ckpt_filename)

pb模型输出每一层节点、值、形状

import tensorflow as tf
import cv2
import numpy as np

# Run one image through a frozen .pb model and print every tensor with its shape.
HEIGHT = 48
WIDTH = 48
PB_PATH = "onet.pb"
IMAGE_PATH = "COCO_val2014_000000000192.jpg"
# Input tensor name as it appears after tf.import_graph_def's default "import/" prefix.
INPUT_TENSOR_NAME = "import/Placeholder_2:0"

image = cv2.imread(IMAGE_PATH)
if image is None:  # cv2.imread signals a missing/unreadable file by returning None
    raise IOError("could not read image: %s" % IMAGE_PATH)
image = cv2.resize(image, (WIDTH, HEIGHT), interpolation=cv2.INTER_CUBIC)
image_np = np.array(image)
image_np_expanded = np.expand_dims(image_np, axis=0).astype(np.float32)

with tf.Graph().as_default() as graph:
    with tf.gfile.GFile(PB_PATH, "rb") as modelfile:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(modelfile.read())
    tf.import_graph_def(graph_def)
    # The feed is the same for every tensor, so build it once.
    feed_dict = {graph.get_tensor_by_name(INPUT_TENSOR_NAME): image_np_expanded}
    with tf.Session() as sess:
        for op in graph.get_operations():
            # Ops such as NoOp have no outputs, so "<name>:0" does not exist for
            # them; iterating op.outputs skips them and covers multi-output ops.
            for tensor in op.outputs:
                value = sess.run(tensor, feed_dict=feed_dict)
                print(tensor, np.shape(value))

去除不必要的 op,修改 op(如 tower、grad 等相关节点),并更改模型输入

def freeze_graph_preprocess1(ckpt_filename_pnet, output_checkpoint="../models/ckpt/pnet_new"):
    """Re-save a P-Net checkpoint through a freshly built inference graph.

    P-Net is rebuilt with is_training=False and is_quantize=False behind a
    fixed (1, 33, 33, 1) float32 placeholder named "input". The weights are
    restored from ``ckpt_filename_pnet`` and saved again. The saved checkpoint
    therefore only carries the ops of this rebuilt graph.

    Args:
        ckpt_filename_pnet: Checkpoint to restore from.
        output_checkpoint: Prefix the new checkpoint is saved under
            (defaults to the previously hard-coded path).
    """
    print("ckpt_filename:  " + ckpt_filename_pnet)
    # Use a private graph so repeated calls do not redefine the same variables
    # in the process-wide default graph.
    with tf.Graph().as_default():
        input_data = tf.placeholder(tf.float32, shape=(1, 33, 33, 1), name="input")
        model.P_Net(input_data, is_training=False, is_quantize=False)

        saver = tf.train.Saver()
        # A context-managed Session is closed on exit; the original
        # InteractiveSession was never closed.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            saver.restore(sess, ckpt_filename_pnet)
            saver.save(sess, output_checkpoint)

    print("success")
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值