TensorFlow2:RNN、LSTM、GRU

本文详细介绍了在TensorFlow2的keras框架下,SimpleRNN、LSTM和GRU层的使用方法,特别是它们的输出形式。RNN通常取最后一个时间步的输出,而LSTM和GRU可以返回每个时间步的状态。通过设置return_sequences和return_state参数,可以获取全部时间步的输出和内部状态。示例代码展示了如何实现这些操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

TensorFlow2中keras框架下layer对象中封装了大量常见循环神经网络层,如keras.layer.SimpleRNN、keras.layer.RNNcell、keras.layer.LSTM、keras.layer.LSTMcell等等类,其中keras.layer.SimpleRnn、keras.layer.LSTM、keras.layer.GRU类就是我们常说的RNN、LSTM、GRU在TensorFlow2对应的函数,下面对几种循环神经网络的输入输出简单介绍

不同的任务,循环神经网络的输出不同,有时取最后一个时间的输出即可,有时要用到全部时间步的输出

RNN

输出最后一个时间步

输入[batch_size,time_step,hidden_nodes]
输出[batch_size,time_step,hidden_nodes]/[batch_size,hidden_nodes]

inputs = np.random.random([32, 10, 8]).astype(np.float32)
simple_rnn = tf.keras.layers.SimpleRNN(4)
output = simple_rnn(inputs)  

输出维度

The output has shape [32, 4].

输出每个时间步状态

simple_rnn = tf.keras.layers.SimpleRNN(
    4, return_sequences=True, return_state=True)
whole_sequence_output, final_state = simple_rnn(inputs)

输出维度

whole_sequence_output has shape [32, 10, 4].
final_state has shape [32, 4].

LSTM

从图中可以看出LSTM的会得到一个状态 c t c_t ct经LSTM的输出门最终得到 h t h_t ht

输出最后一个时间步

inputs = tf.random.normal([32, 10, 8])
lstm = tf.keras.layers.LSTM(4)
output = lstm(inputs)
print(output.shape)

输出维度

The output has shape [32, 4].

获取每个时间步状态、输出

lstm = tf.keras.layers.LSTM(4, return_sequences=True, return_state=True)
whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
print(whole_seq_output.shape)
print(final_memory_state.shape)
print(final_carry_state.shape)

输出维度

whole_sequence_output has shape [32, 10, 4].
final_memory_state has shape [32, 4].
final_carry_state has shape [32, 4].

GRU

输出最后一个时间步

inputs = tf.random.normal([32, 10, 8])
gru = tf.keras.layers.GRU(4)
output = gru(inputs)
print(output.shape)

输出维度

The output has shape(32,4)

输出每个时间步状态

gru = tf.keras.layers.GRU(4, return_sequences=True, return_state=True)
whole_sequence_output, final_state = gru(inputs)
print(whole_sequence_output.shape)
print(final_state.shape)

输出维度

whole_sequence_output has shape(32,10,4)
final_state.shape(32,4)

完整代码

其中RNN的完整训练代码

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值