网络模型保存为pb形式(二)

本文提供了一个使用TensorFlow实现的简单循环神经网络(RNN)示例,通过动态RNN处理序列数据,并演示了如何将训练好的模型转换为静态图进行预测。文中包括不同输入数据维度下的模型应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

网上看到的一个程序,运行了下,有些地方还有错,先存下来,有些地方可以借鉴的
import numpy
import tensorflow as tf
from tensorflow import graph_util as tf_graph_util
from tensorflow.contrib import rnn as tfc_rnn


def v1(data):
    with tf.Graph().as_default():
        tf.set_random_seed(1)
        x = tf.placeholder(tf.float32, shape=(None, None, 5))
        _, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            print (session.run(s, feed_dict={x: data}))


def v2a():
    with tf.Graph().as_default():
        tf.set_random_seed(1)
        x = tf.placeholder(tf.float32, shape=(2, 3, 5), name="x")
        _, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)

        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            return tf_graph_util.convert_variables_to_constants(
                session, session.graph_def, [s.op.name]), s.name


def v2ba(graph_def, s_name, data):
    with tf.Graph().as_default():
        x, s = tf.import_graph_def(graph_def,
                                   return_elements=["x:0", s_name])

        with tf.Session() as session:
            print ('2ba', session.run(s, feed_dict={x: data}))


def v2bb(graph_def, s_name, data):
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=(2, 3, 5))
        [s] = tf.import_graph_def(graph_def, input_map={"x:0": x},
                                  return_elements=[s_name])

        with tf.Session() as session:
            print ('2bb', session.run(s, feed_dict={x: data}))


def v2bc(graph_def, s_name, data):
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=(None, None, 5))
        [s] = tf.import_graph_def(graph_def, input_map={"x:0": x},
                                  return_elements=[s_name])

        with tf.Session() as session:
            print ('2bc', session.run(s, feed_dict={x: data}))


def main():
    data1 = numpy.random.random_sample((2, 3, 5))
    data2 = numpy.random.random_sample((1, 3, 5))
    v1(data1)
    model = v2a()
    v2ba(model, data1)
    v2bb(model, data1)
    v2bc(model, data1)
    v2bc(model, data2)


if __name__ == "__main__":
    main()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值