https://blog.youkuaiyun.com/u010960155/article/details/81707498
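输入 X 的形状为 [3, 6, 4](batch_size=3,max_time=6,每个时间步的输入维度=4)。经过 tf.nn.dynamic_rnn 后,outputs 的形状为 [3, 6, 5](5 为隐藏单元数),state 的形状为 [2, 3, 5]:state 的第一部分为 c,代表 cell state;第二部分为 h,代表 hidden state。可以看到,每个序列的 hidden state h 等于 outputs 中该序列最后一个有效时间步(由 sequence_length 决定)的输出:第 1、3 个序列长度为 6,h 等于其 outputs 的最后一行;第 2 个序列实际长度为 4,h 等于其 outputs 的第 4 行,而超出实际长度的时间步输出全为 0。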
import tensorflow as tf
import numpy as np
tf.reset_default_graph()


def dynamic_rnn(rnn_type='lstm'):
    """Run tf.nn.dynamic_rnn on a small random, zero-padded batch and print the results.

    Builds a batch of 3 sequences (max time 6, input dimension 4) in which the
    second sequence has only 4 real time steps. The batch is fed through a
    single recurrent cell with 5 hidden units. The shape and value of the
    per-step outputs and of the final state are then printed.

    Args:
        rnn_type: 'lstm' selects BasicLSTMCell (state_is_tuple=True); any other
            value selects GRUCell.

    Returns:
        Tuple ``(outputs, last_states)`` as evaluated by ``session.run``.
        ``outputs`` has shape (batch, max_time, hidden) = (3, 6, 5). For
        'lstm', ``last_states`` is an LSTMStateTuple(c, h).
    """
    rnn_hidden_size = 5
    input_dim = 4
    # Actual number of valid time steps for each sequence in the batch.
    X_lengths = [6, 4, 6]
    batch_size = len(X_lengths)
    max_time = max(X_lengths)

    # Random input of shape (batch_size, max_time, input_dim) = (3, 6, 4).
    # np.random.randn produces float64, which is why dynamic_rnn uses tf.float64.
    X = np.random.randn(batch_size, max_time, input_dim)
    # Zero out every step past each sequence's real length. Deriving this from
    # X_lengths keeps the padding and sequence_length in sync (only the
    # second sequence, length 4, is actually padded).
    for i, length in enumerate(X_lengths):
        X[i, length:] = 0

    # Build in a fresh graph so repeated calls do not pile ops and variables
    # into the process-wide default graph.
    with tf.Graph().as_default():
        if rnn_type == 'lstm':
            cell = tf.contrib.rnn.BasicLSTMCell(
                num_units=rnn_hidden_size, state_is_tuple=True)
        else:
            cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
        outputs, last_states = tf.nn.dynamic_rnn(
            cell=cell,
            dtype=tf.float64,
            sequence_length=X_lengths,
            inputs=X)
        with tf.Session() as session:
            session.run(tf.global_variables_initializer())
            o1, s1 = session.run([outputs, last_states])

    print(np.shape(o1))
    print(o1)
    print(np.shape(s1))
    print(s1)
    return o1, s1
if __name__ == '__main__':
    # Run the LSTM variant; any rnn_type other than 'lstm' (e.g. 'gru')
    # selects a GRU cell instead.
    dynamic_rnn(rnn_type='lstm')
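# Sample output of dynamic_rnn(rnn_type='lstm'). Values differ between runs because the input comes from unseeded np.random.randn: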
(3, 6, 5)
[[[-0.03649305 -0.08759684 -0.14881246 0.03393637 0.070974 ]
[ 0.07188818 -0.06839971 -0.25666403 0.12532337 0.0098207 ]
[-0.09513322 -0.13559747 0.03649189 0.05972557 -0.01364573]
[-0.17752969 -0.1838733 0.16408029 -0.02859155 -0.02507258]
[-0.06381911 -0.20662363 0.06501477 -0.04897918 0.20974965]
[-0.05599105 -0.42127601 0.16586181 -0.11377841 0.05137224]]
[[ 0.09163218 0.17676768 -0.05854728 0.03455833 -0.15493153]
[ 0.14222002 0.01853503 -0.10878392 0.01718411 0.08146301]
[ 0.22258063 -0.05167067 0.01259944 -0.00889279 0.09269447]
[ 0.10070281 -0.10465053 -0.06934761 0.02947036 0.11592762]
[ 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. ]]
[[ 0.0886997 -0.00751986 -0.03902694 0.19838859 -0.07955689]
[ 0.11569981 -0.12117248 -0.00385077 0.00903499 0.07933235]
[ 0.27478295 -0.09153444 0.06287552 0.13007742 -0.13177781]
[ 0.21080878 -0.07592999 0.12971802 -0.01637112 -0.04625642]
[ 0.00913372 0.03619587 0.1536368 0.03764711 -0.1069617 ]
[ 0.16520677 -0.12995608 0.23450888 0.15317401 -0.06651008]]]
(2, 3, 5)
LSTMStateTuple(c=array([[-0.06978183, -0.59760795, 0.47491219, -0.2818135 , 0.13637446],
[ 0.22598655, -0.20673746, -0.11502892, 0.05775042, 0.24337772],
[ 0.27777204, -0.17043078, 0.33604336, 0.29077832, -0.14329689]]), h=array([[-0.05599105, -0.42127601, 0.16586181, -0.11377841, 0.05137224],
[ 0.10070281, -0.10465053, -0.06934761, 0.02947036, 0.11592762],
[ 0.16520677, -0.12995608, 0.23450888, 0.15317401, -0.06651008]]))