tensorflow持久化以及恢复

1. 保存为ckpt文件

import tensorflow as tf
import numpy as np





if __name__ == '__main__':

    input_data = tf.placeholder(dtype=tf.float32, shape=[2,3],name='input')
    print("input_data.node_name = " + input_data.name)

    p1 = tf.Variable(initial_value=tf.random_normal(shape=[2,3], mean=1.0, stddev=0.5), name='v1')
    p2 = tf.Variable(initial_value=tf.random_normal(shape=[2,3], mean=1.0, stddev=0.5), name='v2')
    pinput = input_data + p1
    p3 = pinput + p2
    print("p3.node_name = " + p3.name)
    W = tf.Variable(initial_value=tf.random_normal(shape=[3,2], mean=0.3, stddev=0.2), name='w')
    wp = tf.matmul(p3, W)
    print("wp.node_name = " + wp.name)

    saver = tf.train.Saver()

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        saver.save(sess, 'log/model.ckpt')

        ckpt = tf.train.get_checkpoint_state('log')
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)

#2. 将ckpt转化为pb文件

import tensorflow as tf
from tensorflow.python.framework import graph_util

input_checkpoint = 'log/model.ckpt'
output_node_name = 'MatMul'
output_graph = ''
saver = tf.train.import_meta_graph('log/model.ckpt.meta',clear_devices=True)
graph = tf.get_default_graph()
inout_graph_def = graph.as_graph_def()

with tf.Session() as sess:
    saver.restore(sess,input_checkpoint)
    output_graph_def = graph_util.convert_variables_to_constants(sess=sess,
                                                                 input_graph_def=inout_graph_def,
                                                                 output_node_names=output_node_name.split(','))
    with tf.gfile.GFile('log/FirstPB.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())

    print("%d ops in the final graph." % len(output_graph_def.node))

#3 利用pb文件预测

import tensorflow as tf
import numpy as np
from tensorflow.python.framework import graph_util
pb_path = 'log/FirstPB.pb'

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    input_data = np.random.rand(2,3)
    with open(pb_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        tf.import_graph_def(output_graph_def, name="")
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            input_tensor = sess.graph.get_tensor_by_name("input:0")
            output_tensor = sess.graph.get_tensor_by_name("MatMul:0")
            output = sess.run(output_tensor,feed_dict={input_tensor:input_data})
            print("output = {0}".format(output))






评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

城墙郭外斜

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值