TensorFlow 持久化

本文介绍了如何使用TensorFlow进行模型的持久化存储,并通过实例演示了存储与读取模型的过程,包括变量、计算图及参数文件的处理方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

今天开始研究持久化存储,对于一个模型,我们为了方便,不用把模型的源代码都拿过来,可以只要一个记录图和里面参数的文件就好。迁移学习就是这么做的,我们最后只是把输出个数修改一下就完成了

##保存
首先是存储,就是搭好框架以后存进来


import tensorflow as tf
from tensorflow.python.framework import graph_util

#定义变量
v1=  tf.Variable(tf.constant([12.0],shape=[1]),name="v1")
v2=  tf.Variable(tf.constant([1.0],shape=[1]),name="v2")


#定义placeholder
p1=tf.placeholder(tf.float32,[None,2],name='p1')
p2=tf.placeholder(tf.float32,[None,2],name='p2')


#定义图
result=tf.multiply(v1,v2,name="result") #节点名称 result 可以从tensorboard看到
print(result.name) #张量的名称result:0
result2=v1*v2*2    #节点名称 mul_1   可以从tensorboard看到
print(result2.name)   #张量的名称mul_1:0
result3=v1+v2       #节点名称   add  可以从tensorboard看到
print(result3.name)  #张量的名称add:0

#计算placeholder
result4=tf.subtract(p1,p2,name='result_substract')
print(result4.name)


#存为json格式
server=tf.train.Saver()
server.export_meta_graph('json.ckpt.meda.json',as_text=True)

#写到TensorBoard可视化
writer=tf.summary.FileWriter('log',tf.get_default_graph())
writer.close()


with tf.Session() as sess:
    #申明一个graph对像
    tf.global_variables_initializer().run()
    graph_def=tf.get_default_graph().as_graph_def()

    #把变量转为常量 ,注意 !!!后面的['result','mul_1'],我们是取他的节点名称
    #但是到时候我们读取的时候,是读取张量的名称,就是后面加  ':0'
     output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['result','mul_1','result_substract','p1','p2'] )

    #写成pb文件
    with tf.gfile.GFile('model.pb','wb') as f:
        f.write(output_graph_def.SerializeToString())

这个可以程序中生成了graph可视化的log文件夹,在cmd中输入:

tensorboard --logdir= log

然后在chrome浏览器中输入网址 http://localhost:6006 就可以从上图就可以看到图像结构
TensorBoard

##读取


import tensorflow as tf
from tensorflow.python.platform import gfile


with tf.Session() as sess:
    with gfile.FastGFile('model.pb','rb') as f:
        graph_def =tf.GraphDef()
        graph_def.ParseFromString(f.read())
    
    #读取普通的变量
    result=tf.import_graph_def(graph_def,return_elements= ['result:0','mul_1:0'])
    print(sess.run(result))

    #对于placeholder,要从模型中读进来,再给赋值后计算
    result4,p1,p2 = tf.import_graph_def(graph_def, return_elements=['result_substract:0','p1:0','p2:0'])
    print(sess.run(result4,{p1:[[3,3],[4,4]],p2:[[1,1],[2,2]]}))

#输出:

[array([ 12.], dtype=float32), array([ 24.], dtype=float32)]
[[ 2.  2.]
 [ 2.  2.]]

其中,有很多小函数,打印计算节点信息:

  for op in tf.get_default_graph().get_operations():  # 打印模型节点信息
      print (op.name, op.values())

从Session中保存pb文件

 output_graph = "frozen_model.pb"
     graph_def = tf.get_default_graph().as_graph_def()
     output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, ["Sigmoid"])
     with tf.gfile.GFile(output_graph, "wb") as f:  # 保存模型
         f.write(output_graph_def.SerializeToString())  # 序列化输出

打印所以ckpt文件中的节点信息


from tensorflow.python import pywrap_tensorflow
import os
def print_tensor():
    reader=pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
    var_to_shape_map=reader.get_variable_to_shape_map()
    for k in sorted(var_to_shape_map):
        print(k, var_to_shape_map[k])

打印检查点所有的变量

from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file(input_checkpoint, tensor_name='', all_tensors=True)

for op in tf.get_default_graph().get_operations(): #打印模型节点信息
    print (op.name, op.values())
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值