Keras中RNN LSTM循环神经网络中 Return Sequences 与 Return States 的区别

本文详细介绍了Keras中LSTM模型的return_sequences和return_state参数的作用。return_sequences返回每个时间步长的隐藏状态输出,而return_state仅返回最后一个时间步长的隐藏状态和单元状态。理解这两者有助于构建更复杂的RNN模型,如在序列到序列任务中。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Keras在GitHub上的一句话简介是“Deep Learning for humans”(给人用的深度学习工具),确实,基于Tensorflow的Keras相对前者来说,能够更加方便快捷地构建神经网络,使非计算机专业的使用者能够将更多精力放在研究问题本身,而不是构建神经网络模型的过程。

在Keras的循环神经网络(RNN)中,有两个参数return_statereturn_sequences

对于一般人来说容易混淆或者不清楚具体的用途用法,在经过一段时间的搜集与测试后,我把我的理解写在这篇博客上。

希望读者读完这篇博客后,可以知道:

  1. return_sequences 返回的是每一个输入时间步长(each input time step)的隐藏状态输出( the hidden state output);
  2. return_state 返回的是最后一个时间步长(the last input time step)的隐藏状态输出( the hidden state output)以及单元状态( cell state);
  3. return_sequencesreturn_state 两者可以同时使用。

要具体理解这两个参数的意思,我们首先需要了解一下RNN/LSTM的简单背景,这里只做简单介绍,具体细节还请参考具体的教科书或者论文。

RNN/LSTM 简介

最基本的循环神经网络RNN会遭遇“梯度消失”(vanishing gradients)的问题,因此很难捕捉长期的时间相关关系。
而Long Short-Term Memory(LSTM)能够解决这一问题,因为在每一个RNN单元中引入了“门”(gate)这个概念,使得模型能够通过反向传播(backpropagation)成功训练以表面梯度消失问题。
LSTM结构
在上图中可以看到对每一个RNN单元,有对应的输入数据X<t> (t = 1,2,3…),以及上一个单元的输出a<t-1> ,对每个单元,有隐藏状态输出( the hidden state output)以及单元状态( cell state)两个参数,之前提到的a<t> 就是隐藏状态输出,前一个单元的隐藏状态输出又会作为后一个单元的输入。对于每一个单元本身的单元状态( cell state),不同的RNN算法也有对应的计算方法。比如下图的GRU与LSTM,对于GRU, a<t> = c<t> ,而LSTM算法中,两者是不同的。
GRU/LSTM
上面的c<t> 的结果就是所谓的单元状态( cell state)。不用担心有点晕,其实只需要记住两点:

  1. a<t> 对应隐藏状态输出( the hidden state output)
  2. c<t> 对应单元状态( cell state)
    上面俩对应的就是之前提到的return_sequencesreturn_state这两个参数的设置。

Return sequences

Return sequences返回的是隐藏状态输出( the hidden state output)a<t>,默认设置为False,这意味着Keras只会输出最后一个隐藏层的结果,这个结果可以认为是整段输入序列的简单代表,有时候这正是我们想要的结果,比如一些特定的分类任务(如文本情感分析)或者回归任务(预测明天的股价)。
但是有的时候我们需要整个序列的输出,这时就需要把return_sequences设置为True。
下面用一个简单的例子来说明。
假设我们只有一个输入数据,这个数据有3个时间步长,最后的输出的也是一维数据,比如:

t1 = 0.1
t2 = 0.2
t3 = 0.3

所以我们的问题就可以简化成:输入3个连续的数字,得到一个数字。
完整的代码如下:

from keras.models import Model
from keras.layers import Input
from keras.layers import LSTM
from numpy import array
# define model
#shape=(3,1)代表输入数据的维度,问题是3个输入的数字,一个输出的数字,所以是(3,1)
inputs1 = Input(shape=(3, 1))
lstm1 = LSTM(1)(inputs1)
model = Model(inputs=inputs1, outputs=lstm1
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值