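In tf.nn.rnn_cell.BasicRNNCell(n_hidden), the argument (named num_units in the TensorFlow API) is the number of hidden units (neurons) in the cell. It also sets the size of the cell's output and state.
For example: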
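import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.nn.rnn_cell)

batch_size = 4
# time_major=True, so the input shape is [max_time, batch_size, input_depth] = [3, 4, 6]
inputs = tf.random_normal(shape=[3, batch_size, 6], dtype=tf.float32)
cell = tf.nn.rnn_cell.BasicRNNCell(10)  # 10 hidden units
init_state = cell.zero_state(batch_size, dtype=tf.float32)
output, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state, time_major=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Note: tf.random_normal is re-sampled on every sess.run call,
    # so the two calls below see different inputs.
    print(sess.run(output))       # shape [3, 4, 10] = [max_time, batch_size, n_hidden]
    print(sess.run(final_state))  # shape [4, 10] = [batch_size, n_hidden]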
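Output. The first array is output, with shape [max_time, batch_size, n_hidden] = [3, 4, 10]; the second is final_state, with shape [batch_size, n_hidden] = [4, 10]. Because the random input is re-sampled on each sess.run call, this final_state does not equal the last time step of output; fetching both in a single call, sess.run([output, final_state]), would make them match.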
[[[-0.36194739 -0.5664643 0.71341908 -0.56548703 -0.6058557 0.15607478
-0.10932037 -0.76532066 -0.15569483 0.5749777 ]
[-0.47865775 -0.85153252 0.11955925 -0.47678211 -0.26779744 -0.16315795
-0.85670316 -0.29747197 -0.74362296 -0.11782304]
[-0.35551894 -0.22971147 0.87532502 -0.07564095 -0.3109358 0.40605015
0.42417526 -0.73830104 0.63733381 0.29208559]
[ 0.17648573 0.90195322 -0.66908085 0.62597507 0.36367226 -0.41078696
0.82797366 0.7970643 0.48825309 -0.09092989]]
[[-0.07066768 -0.58495802 -0.83810818 -0.22170046 0.28530884 -0.48797613
-0.86179078 0.53488874 -0.30063036 -0.19637674]
[-0.24391828 0.08400524 -0.89338982 -0.31769255 -0.80121225 -0.66595536
-0.6133672 -0.19677906 0.30365667 0.23569871]
[ 0.55269027 0.57405007 -0.66748625 0.11129615 0.4685905 -0.31985056
0.37982267 0.60275972 -0.28347531 0.81068254]
[-0.30811819 -0.46662089 -0.5317077 -0.44609445 0.11240361 -0.48326215
-0.68652773 -0.73142618 0.45866293 -0.50407058]]
[[-0.6940583 0.51343572 -0.54493487 -0.73246908 -0.96255547 -0.51650691
0.32794529 -0.7064063 -0.6840449 0.40109596]
[-0.48864028 -0.63549376 -0.30771643 -0.30445376 0.278009 -0.08625165
-0.59299129 0.2232109 -0.85149229 -0.77802432]
[-0.34051719 -0.16116273 -0.69728005 -0.46142533 0.28736579 -0.46011281
-0.35864782 0.17567375 0.26353961 -0.81816453]
[-0.05024701 -0.38233131 0.31979072 -0.36023989 -0.56383204 0.55900681
0.13344656 -0.35502923 -0.88185179 0.06796031]]]
[[ 0.05491487 0.58953148 0.73639983 -0.67890507 0.37539703 -0.17965442
0.85910887 0.02395871 0.24805558 0.41260818]
[ 0.08956354 0.80021787 -0.77623397 0.11130377 0.04178546 -0.36267385
0.40600157 0.84655112 0.14686334 0.24669757]
[ 0.51759291 0.67971212 0.16224268 0.77545607 0.12878148 0.77131855
-0.35831013 -0.73500431 0.53704679 -0.18410714]
[ 0.11651993 0.4543218 -0.51284182 0.72710162 0.24540463 -0.30895576
-0.4665778 -0.0014685 0.91048062 -0.22179691]]
For more on how this function is used, see: http://blog.youkuaiyun.com/uestc_c2_403/article/details/73353145