tf.nn.dynamic_rnn的输出outputs和state含义

https://blog.youkuaiyun.com/u010960155/article/details/81707498

输入的形状为 [ 3, 6, 4 ],经过tf.nn.dynamic_rnn后outputs的形状为 [ 3, 6, 5 ],state形状为 [ 2, 3, 5 ],其中state第一部分为c,代表cell state;第二部分为h,代表hidden state。可以看到hidden state 与 对应的outputs的最后一行是相等的
 

import tensorflow as tf
import numpy as np
tf.reset_default_graph() 
def dynamic_rnn(rnn_type='lstm'):
    # 创建输入数据,3代表batch size,6代表输入序列的最大步长(max time),8代表每个序列的维度
    X = np.random.randn(3, 6, 4)
 
    # 第二个输入的实际长度为4
    X[1, 4:] = 0
 
    #记录三个输入的实际步长
    X_lengths = [6, 4, 6]
 
    rnn_hidden_size = 5
    if rnn_type == 'lstm':
        cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
 
    outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        sequence_length=X_lengths,
        inputs=X)
 
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        o1, s1 = session.run([outputs, last_states])
        print(np.shape(o1))
        print(o1)
        print(np.shape(s1))
        print(s1)
 
 
if __name__ == '__main__':
    dynamic_rnn(rnn_type='lstm')
# result as follows:

(3, 6, 5)
[[[-0.03649305 -0.08759684 -0.14881246  0.03393637  0.070974  ]
  [ 0.07188818 -0.06839971 -0.25666403  0.12532337  0.0098207 ]
  [-0.09513322 -0.13559747  0.03649189  0.05972557 -0.01364573]
  [-0.17752969 -0.1838733   0.16408029 -0.02859155 -0.02507258]
  [-0.06381911 -0.20662363  0.06501477 -0.04897918  0.20974965]
  [-0.05599105 -0.42127601  0.16586181 -0.11377841  0.05137224]]

 [[ 0.09163218  0.17676768 -0.05854728  0.03455833 -0.15493153]
  [ 0.14222002  0.01853503 -0.10878392  0.01718411  0.08146301]
  [ 0.22258063 -0.05167067  0.01259944 -0.00889279  0.09269447]
  [ 0.10070281 -0.10465053 -0.06934761  0.02947036  0.11592762]
  [ 0.          0.          0.          0.          0.        ]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.0886997  -0.00751986 -0.03902694  0.19838859 -0.07955689]
  [ 0.11569981 -0.12117248 -0.00385077  0.00903499  0.07933235]
  [ 0.27478295 -0.09153444  0.06287552  0.13007742 -0.13177781]
  [ 0.21080878 -0.07592999  0.12971802 -0.01637112 -0.04625642]
  [ 0.00913372  0.03619587  0.1536368   0.03764711 -0.1069617 ]
  [ 0.16520677 -0.12995608  0.23450888  0.15317401 -0.06651008]]]
(2, 3, 5)
LSTMStateTuple(c=array([[-0.06978183, -0.59760795,  0.47491219, -0.2818135 ,  0.13637446],
       [ 0.22598655, -0.20673746, -0.11502892,  0.05775042,  0.24337772],
       [ 0.27777204, -0.17043078,  0.33604336,  0.29077832, -0.14329689]]), h=array([[-0.05599105, -0.42127601,  0.16586181, -0.11377841,  0.05137224],
       [ 0.10070281, -0.10465053, -0.06934761,  0.02947036,  0.11592762],
       [ 0.16520677, -0.12995608,  0.23450888,  0.15317401, -0.06651008]]))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值