# Found this program online; I ran it and some parts still fail. Saving it for now because parts of it are worth borrowing.
import numpy
import tensorflow as tf
from tensorflow import graph_util as tf_graph_util
from tensorflow.contrib import rnn as tfc_rnn
def v1(data):
    """Build a BasicRNNCell(7) graph from scratch, run it on ``data``, print the state.

    The input placeholder has shape ``(None, None, 5)``, so only the last
    dimension of ``data`` is fixed by the graph.

    Args:
        data: Array fed to the placeholder; its last dimension must be 5.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Graph-level seed so variable initialisation is repeatable.
        tf.set_random_seed(1)
        inputs = tf.placeholder(tf.float32, shape=(None, None, 5))
        cell = tfc_rnn.BasicRNNCell(7)
        # dynamic_rnn returns (outputs, state); only the state is used here.
        _outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            result = sess.run(final_state, feed_dict={inputs: data})
            print (result)
def v2a():
    """Build the RNN graph with a fixed-shape input named "x" and freeze it.

    The placeholder is declared with the static shape ``(2, 3, 5)``. After the
    variables are initialised they are converted to constants so the returned
    GraphDef can be imported without a checkpoint.

    Returns:
        A ``(graph_def, state_name)`` tuple: the frozen GraphDef and the name
        of the RNN state tensor inside it.
    """
    graph = tf.Graph()
    with graph.as_default():
        # Graph-level seed so variable initialisation is repeatable.
        tf.set_random_seed(1)
        inputs = tf.placeholder(tf.float32, shape=(2, 3, 5), name="x")
        cell = tfc_rnn.BasicRNNCell(7)
        _outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Freeze only what is needed to compute the state op.
            frozen_graph_def = tf_graph_util.convert_variables_to_constants(
                sess, sess.graph_def, [final_state.op.name])
            return frozen_graph_def, final_state.name
def v2ba(graph_def, s_name, data):
    """Import a frozen graph and feed ``data`` straight into its own "x:0" input.

    Args:
        graph_def: GraphDef to import into a fresh graph.
        s_name: Name of the tensor to fetch from the imported graph.
        data: Value fed to the imported "x:0" tensor.
    """
    with tf.Graph().as_default():
        imported = tf.import_graph_def(
            graph_def, return_elements=["x:0", s_name])
        # Elements come back in the order they were requested.
        inputs, state = imported
        with tf.Session() as sess:
            print ('2ba', sess.run(state, feed_dict={inputs: data}))
def v2bb(graph_def, s_name, data):
    """Import a frozen graph, remapping "x:0" to a new fixed-shape placeholder.

    The replacement placeholder has the static shape ``(2, 3, 5)``.

    Args:
        graph_def: GraphDef to import into a fresh graph.
        s_name: Name of the tensor to fetch from the imported graph.
        data: Value fed to the replacement placeholder.
    """
    with tf.Graph().as_default():
        inputs = tf.placeholder(tf.float32, shape=(2, 3, 5))
        # Substitute the new placeholder for the graph's original input.
        (state,) = tf.import_graph_def(
            graph_def,
            input_map={"x:0": inputs},
            return_elements=[s_name])
        with tf.Session() as sess:
            print ('2bb', sess.run(state, feed_dict={inputs: data}))
def v2bc(graph_def, s_name, data):
    """Import a frozen graph, remapping "x:0" to a partially-unknown-shape placeholder.

    The replacement placeholder has shape ``(None, None, 5)``, i.e. only the
    last dimension is fixed on the new input.

    Args:
        graph_def: GraphDef to import into a fresh graph.
        s_name: Name of the tensor to fetch from the imported graph.
        data: Value fed to the replacement placeholder.
    """
    with tf.Graph().as_default():
        inputs = tf.placeholder(tf.float32, shape=(None, None, 5))
        # Substitute the new placeholder for the graph's original input.
        (state,) = tf.import_graph_def(
            graph_def,
            input_map={"x:0": inputs},
            return_elements=[s_name])
        with tf.Session() as sess:
            print ('2bc', sess.run(state, feed_dict={inputs: data}))
def main():
    """Run every variant of the RNN demo on random input.

    ``v1`` builds and runs the graph directly. ``v2a`` builds and freezes the
    graph once; the ``v2b*`` functions then re-import that frozen graph in
    different ways and print the state they compute.
    """
    data1 = numpy.random.random_sample((2, 3, 5))
    data2 = numpy.random.random_sample((1, 3, 5))
    v1(data1)
    # v2a returns a (graph_def, state_name) pair; the v2b* functions take the
    # two parts as separate arguments, so unpack before passing them on.
    graph_def, s_name = v2a()
    v2ba(graph_def, s_name, data1)
    v2bb(graph_def, s_name, data1)
    v2bc(graph_def, s_name, data1)
    # NOTE(review): the frozen graph was built from a (2, 3, 5) placeholder,
    # so this batch-size-1 input may be rejected at run time -- confirm.
    v2bc(graph_def, s_name, data2)
if __name__ == "__main__":
    # Only run the demo when executed as a script, not when imported.
    main()