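The argument to tf.nn.rnn_cell.BasicRNNCell(n_hidden) is the number of hidden units (neurons) in the cell. It also fixes the size of the output and state vector produced at each time step.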
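To make the role of this parameter concrete, here is a minimal NumPy sketch of a plain tanh recurrence of the kind BasicRNNCell computes. It is not TensorFlow's implementation: the function name basic_rnn_forward, the random weight initialization, and the seed are assumptions made only for illustration. What it shows is that n_hidden sets the width of the weight matrix, and therefore the last dimension of every output and of the final state.

```python
import numpy as np

def basic_rnn_forward(inputs, n_hidden, seed=0):
    """Run a plain tanh RNN over time-major inputs [time, batch, input_dim].

    Hypothetical helper for illustration; weights are random, not trained.
    """
    time_steps, batch_size, input_dim = inputs.shape
    rng = np.random.default_rng(seed)
    # A single weight matrix acts on the concatenation [x_t, h_prev];
    # n_hidden determines its number of columns.
    kernel = rng.normal(scale=0.1, size=(input_dim + n_hidden, n_hidden))
    bias = np.zeros(n_hidden)
    state = np.zeros((batch_size, n_hidden))  # zero initial state
    outputs = []
    for t in range(time_steps):
        # h_t = tanh([x_t, h_{t-1}] @ W + b)
        state = np.tanh(np.concatenate([inputs[t], state], axis=1) @ kernel + bias)
        outputs.append(state)
    return np.stack(outputs), state

# Same shapes as a 3-step, batch-of-4, 6-feature input with 10 hidden units
inputs = np.random.default_rng(1).normal(size=(3, 4, 6))
outputs, final_state = basic_rnn_forward(inputs, n_hidden=10)
print(outputs.shape)      # (3, 4, 10)
print(final_state.shape)  # (4, 10)
```

In this sketch the final state is simply the output of the last time step, which is why its shape is [batch_size, n_hidden].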
For example:
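import tensorflow as tf  # TensorFlow 1.x API

batch_size = 4
# Time-major input: [time_steps=3, batch_size=4, input_dim=6]
inputs = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
cell = tf.nn.rnn_cell.BasicRNNCell(10)  # 10 hidden units
init_state = cell.zero_state(batch_size, dtype=tf.float32)
output, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state, time_major=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Fetch both tensors in one run so they are computed from the same random input
    output_val, final_state_val = sess.run([output, final_state])
    print(output_val.shape)  # (3, 4, 10): [time_steps, batch_size, n_hidden]
    print(final_state_val)   # array of shape (4, 10): [batch_size, n_hidden]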
Output:
[[[-0.36194739 -0.5664643 0.71341908 -0.56548703 -0.6058557 0.15607478
-0.10932037 -0.76532066 -0.15569483 0.5749777 ]
[-0.47865775 -0.85153252 0.11955925 -0.47678211 -0.26779744 -0.16315795
-0.85670316 -0.29747197 -0.74362296 -0.11782304]
[-0.35551894 -0.22971147 0.87532502 -0.07564095 -0.3109358