LSTM 详细公式与图解

Gated Recurrent Unit

GRU 和 LSTM 都可以很好的解决 RNN 中的梯度消失问题,而 GRU 与 LSTM 在某些方面很相似,为了阐述 LSTM,先阐述 GRU。

下图所示是普通 RNN 单元

GRU 的 RNN 单元与其类似,但有所不同,其中对于 a 的计算分为三部:

  1. 计算 a~⟨t⟩=tanh(wa[a⟨t−1⟩,x⟨t⟩]+ba)\tilde{a}^{\langle t \rangle} = tanh(w_a[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_a)a~t=tanh(wa[at1,xt]+ba)
  2. 计算 Γu=σ(wu[a⟨t−1⟩,x⟨t⟩]+bu)\Gamma_u = \sigma(w_u[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_u)Γu=σ(wu[at1,xt]+bu)
  3. 最终 a⟨t⟩=Γu⋅a~⟨t⟩+(1−Γu)⋅a⟨t−1⟩a^{\langle t \rangle} = \Gamma_u \cdot \tilde{a}^{\langle t \rangle} + (1-\Gamma_u) \cdot a^{\langle {t-1} \rangle}at=Γua~t+(1Γu)at1

其中 Γu\Gamma_uΓuupdate gate,即更新门,其值域为 [0,1][0, 1][0,1]. 从上式可以看出,最终的 a⟨t⟩a^{\langle t \rangle}at 是当前激活值与一个时间步骤前的激活值的线性组合,通过这种方式,可以使得先前激活值有一定概率传播到当前激活值,即记住了句子之前的信息。然后用最终的 a⟨t⟩a^{\langle t \rangle}at 计算 y⟨t⟩y^{\langle t \rangle}yt.

另外为了与普通 RNN 单元进行区分,GRU 中的激活值一般以 c 表示,将上式中的 a 替换为 c 即可,下面将使用 c 阐述其他内容。

目前为止介绍的 GRU 其实做了简化,完整的 GRU 还有一个相关门,即 relevant gate,用来确定 c⟨t−1⟩c^{\langle {t-1} \rangle}ct1c⟨t⟩c^{\langle t \rangle}ct 的相关程度,加入了更新门后对于 c 的计算过程如下。

  1. 计算 Γr=σ(wr[c⟨t−1⟩,x⟨t⟩]+br)\Gamma_r = \sigma(w_r[c^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_r)Γr=σ(wr[ct1,xt]+br)
  2. 计算 c~⟨t⟩=tanh(wc[Γr⋅c⟨t−1⟩,x⟨t⟩]+bc)\tilde{c}^{\langle{t}\rangle} = tanh(w_c[\Gamma_r \cdot c^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_c)c~t=tanh(wc[Γrct1,xt]+bc)
  3. 计算 Γu=σ(wu[c⟨t−1⟩,x⟨t⟩]+bu)\Gamma_u = \sigma(w_u[c^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_u)Γu=σ(wu[ct1,xt]+bu)
  4. 最终 c⟨t⟩=Γu⋅c~⟨t⟩+(1−Γu)⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + (1-\Gamma_u) \cdot c^{\langle {t-1} \rangle}ct=Γuc~t+(1Γu)ct1

其中第一步计算 relevant gate,第二步利用 relevatn gate 计算 c~⟨t⟩\tilde{c}^{\langle t \rangle}c~t,其余步骤与之前相同。

关于为什么要使用这样的 RNN 单元,Andrew NG 对此有下面这一番话:

So why we use these architectures, why don’t we change them, how we know they will work, why not add another gate, why not use the simpler GRU instead of the full GRU; well researchers has experimented over years all the various types of these architectures with many many different versions and also addressing the vanishing gradient problem. They have found that full GRUs are one of the best RNN architectures to be used for many different problems. You can make your design but put in mind that GRUs and LSTMs are standards.

Long Short Term Memory

GRU 在解决梯度消失问题上的表现很不错,但在 GRU 提出之前,LSTM 存在已久,而 LSTM 比起 GRU 使用得更加普遍。

LSTM 与 GRU 十分相似。在 GRU 中,我们有 update gaterelevant gate,以及激活单元 c,而在 LSTM 中,没有_relevant gate_,但新增了 forget gateoutput gate,以及激活单元 c 和 a,下面我们来详细阐述。

在 GRU 中,我们使用 update gate 的来控制激活单元是否更新以及更新的程度,其目的是减少激活单元更新的次数或程度,好让之前的激活单元的值得到保留,换言之,记住句子前面部分的信息,这一点在 LSTM 中并没改变,只不过相比于 GRU 使用 1−Γu1 - \Gamma_u1Γu 来表示不更新的概率,LSTM 直接使用一个 forget gateΓf\Gamma_fΓf 来代替 1−Γu1 - \Gamma_u1Γu,下表可以清楚看出 GRU 与 LSTM 在计算 c⟨t⟩c^{\langle t \rangle}ct 时的区别。

GRULSTM
c⟨t⟩=Γu⋅c~⟨t⟩+(1−Γu)⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + (1-\Gamma_u) \cdot c^{\langle{t-1}\rangle}ct=Γuc~t+(1Γu)ct1c⟨t⟩=Γu⋅c~⟨t⟩+Γf⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + \Gamma_f \cdot c^{\langle{t-1}\rangle}ct=Γuc~t+Γfct1

另一个门,即 output gate 的作用是进一步控制激活单元更新的程度,在 GRU 中,上表算出的 c⟨t⟩c^{\langle t \rangle}ct 就是激活单元,而在 LSTM 中还需进一步计算,再用一张表表示。

GRU 的激活单元LSTM 的激活单元
c⟨t⟩=Γu⋅c~⟨t⟩+(1−Γu)⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + (1-\Gamma_u) \cdot c^{\langle{t-1}\rangle}ct=Γuc~t+(1Γu)ct1c⟨t⟩=Γu⋅c~⟨t⟩+Γf⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + \Gamma_f \cdot c^{\langle{t-1}\rangle}ct=Γuc~t+Γfct1
a⟨t⟩=Γo⋅tanh(c⟨t⟩)a^{\langle t \rangle} = \Gamma_o \cdot tanh(c^{\langle t \rangle})at=Γotanh(ct)

从表中可以看出,LSTM 的最终激活单元是 a,即 y⟨t⟩y^{\langle t \rangle}yt 是通过 a⟨t⟩a^{\langle t \rangle}at 的计算得出的,c⟨t⟩c^{\langle t \rangle}ct 只是中间变量,不过 c⟨t⟩c^{\langle t \rangle}cta⟨t⟩a^{\langle t \rangle}at 都会传向下一个单元,一会会用一张图表示这个过程。

介绍了 forget gateoutput gate 的作用后,让我们把 LSTM 的激活单元计算过程中涉及的计算式完整写一遍:

  1. c~⟨t⟩=tanh(wc[a⟨t−1⟩,x⟨t⟩]+bc)\tilde{c}^{\langle t \rangle} = tanh(w_c[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_c)c~t=tanh(wc[at1,xt]+bc)
  2. Γu=σ(wu[a⟨t−1⟩,x⟨t⟩]+bu)\Gamma_u = \sigma(w_u[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_u)Γu=σ(wu[at1,xt]+bu)
  3. Γf=σ(wf[a⟨t−1⟩,x⟨t⟩]+bf)\Gamma_f = \sigma(w_f[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_f)Γf=σ(wf[at1,xt]+bf)
  4. Γo=σ(wo[a⟨t−1⟩,x⟨t⟩]+bo)\Gamma_o = \sigma(w_o[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_o)Γo=σ(wo[at1,xt]+bo)
  5. c⟨t⟩=Γu⋅c~⟨t⟩+Γf⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + \Gamma_f \cdot c^{\langle{t-1}\rangle}ct=Γuc~t+Γfct1
  6. a⟨t⟩=Γo⋅tanh(c⟨t⟩)a^{\langle t \rangle} = \Gamma_o \cdot tanh(c^{\langle t \rangle})at=Γotanh(ct)

计算顺序不一定按照上面的序号来。可以用一张图来表示 LSTM 的 RNN 单元的计算过程:

从图中可以看出,c⟨t⟩c^{\langle t \rangle}cta⟨t⟩a^{\langle t \rangle}at 都传向了下一个单元(这里说下一个单元有些不太准确,准确形容应该是 the next time step),但只有 a⟨t⟩a^{\langle t \rangle}at 参与了 y⟨t⟩y^{\langle t \rangle}yt 的计算。

下面这幅图展示了 LSTM 的前向传播过程。

### LSTM公式详解及数学表达式 长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(RNN),其核心在于通过门控机制控制信息的流动,从而解决传统RNN中的梯度消失和梯度爆炸问题。以下是LSTM的核心公式及其详细的数学表达。 #### 输入状态定义 在任意时间步 \( t \),LSTM接收三个主要输入:当前时刻的输入 \( x_t \)[^2]、上一时刻隐藏状态 \( h_{t-1} \) 和上一时刻单元状态 \( c_{t-1} \)。基于这些输入,LSTM计算新的隐藏状态 \( h_t \) 和单元状态 \( c_t \)。 --- #### 门控机制 LSTM的关键部分是三种不同的门控机制:遗忘门、输入门和输出门。每种门都由一个sigmoid层组成,用于决定哪些信息应该保留或丢弃。 ##### **遗忘门** 遗忘门决定了多少来自上一时刻的状态 \( c_{t-1} \) 应该被遗忘。它依赖于当前输入 \( x_t \) 和前一时刻的隐藏状态 \( h_{t-1} \): \[ f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \] 其中: - \( W_f \): 遗忘门权重矩阵; - \( b_f \): 偏置项; - \( \sigma \): Sigmoid激活函数。 --- ##### **输入门** 输入门负责更新细胞状态的部分新候选值的选择。这分为两部分: 1. 计算输入门的激活向量: \[ i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \] 2. 使用 tanh 层生成一个新的候选值向量 \( \tilde{c}_t \): \[ \tilde{c}_t = \text{tanh}(W_c \cdot [h_{t-1}, x_t] + b_c) \] 最终的新细胞状态 \( c_t \) 是旧状态 \( c_{t-1} \) 的一部分加上新候选值的一部分: \[ c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \] 这里 \( \odot \) 表示逐元素乘法。 --- ##### **输出门** 输出门决定了下一隐藏状态 \( h_t \) 中要暴露给外部的信息。首先计算输出门的激活向量: \[ o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \] 然后利用经过 tanh 处理后的细胞状态 \( c_t \) 来生成最终的隐藏状态: \[ h_t = o_t \odot \text{tanh}(c_t) \] --- #### 总结公式 综合以上各部分,完整的LSTM公式如下所示: 1. 遗忘门: \[ f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \] 2. 输入门: \[ i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i), \quad \tilde{c}_t = \text{tanh}(W_c \cdot [h_{t-1}, x_t] + b_c) \] 3. 细胞状态更新: \[ c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \] 4. 输出门: \[ o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o), \quad h_t = o_t \odot \text{tanh}(c_t) \] --- ### 实现代码示例 (PyTorch) 以下是一个简单的LSTM实现代码片段,展示了如何构建上述公式的张量操作: ```python import torch import torch.nn as nn class LSTMCell(nn.Module): def __init__(self, input_size, hidden_size): super(LSTMCell, self).__init__() self.input_size = input_size self.hidden_size = hidden_size # 定义权重参数 self.W_f = nn.Linear(input_size + hidden_size, hidden_size) self.W_i = nn.Linear(input_size + hidden_size, hidden_size) self.W_c = nn.Linear(input_size + hidden_size, hidden_size) self.W_o = nn.Linear(input_size + hidden_size, hidden_size) def forward(self, x, states): h_prev, c_prev = states combined_input = torch.cat((h_prev, x), dim=1) # 遗忘门 f_t = torch.sigmoid(self.W_f(combined_input)) # 输入门 i_t = torch.sigmoid(self.W_i(combined_input)) c_hat_t = torch.tanh(self.W_c(combined_input)) # 更新细胞状态 c_t = f_t * c_prev + i_t * c_hat_t # 输出门 o_t = torch.sigmoid(self.W_o(combined_input)) h_t = o_t * torch.tanh(c_t) return h_t, c_t ``` ---
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值