LSTM 详细公式与图解

Gated Recurrent Unit

GRU 和 LSTM 都可以很好的解决 RNN 中的梯度消失问题,而 GRU 与 LSTM 在某些方面很相似,为了阐述 LSTM,先阐述 GRU。

下图所示是普通 RNN 单元

GRU 的 RNN 单元与其类似,但有所不同,其中对于 a 的计算分为三部:

  1. 计算 a~⟨t⟩=tanh(wa[a⟨t−1⟩,x⟨t⟩]+ba)\tilde{a}^{\langle t \rangle} = tanh(w_a[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_a)a~t=tanh(wa[at1,xt]+ba)
  2. 计算 Γu=σ(wu[a⟨t−1⟩,x⟨t⟩]+bu)\Gamma_u = \sigma(w_u[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_u)Γu=σ(wu[at1,xt]+bu)
  3. 最终 a⟨t⟩=Γu⋅a~⟨t⟩+(1−Γu)⋅a⟨t−1⟩a^{\langle t \rangle} = \Gamma_u \cdot \tilde{a}^{\langle t \rangle} + (1-\Gamma_u) \cdot a^{\langle {t-1} \rangle}at=Γua~t+(1Γu)at1

其中 Γu\Gamma_uΓuupdate gate,即更新门,其值域为 [0,1][0, 1][0,1]. 从上式可以看出,最终的 a⟨t⟩a^{\langle t \rangle}at 是当前激活值与一个时间步骤前的激活值的线性组合,通过这种方式,可以使得先前激活值有一定概率传播到当前激活值,即记住了句子之前的信息。然后用最终的 a⟨t⟩a^{\langle t \rangle}at 计算 y⟨t⟩y^{\langle t \rangle}yt.

另外为了与普通 RNN 单元进行区分,GRU 中的激活值一般以 c 表示,将上式中的 a 替换为 c 即可,下面将使用 c 阐述其他内容。

目前为止介绍的 GRU 其实做了简化,完整的 GRU 还有一个相关门,即 relevant gate,用来确定 c⟨t−1⟩c^{\langle {t-1} \rangle}ct1c⟨t⟩c^{\langle t \rangle}ct 的相关程度,加入了更新门后对于 c 的计算过程如下。

  1. 计算 Γr=σ(wr[c⟨t−1⟩,x⟨t⟩]+br)\Gamma_r = \sigma(w_r[c^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_r)Γr=σ(wr[ct1,xt]+br)
  2. 计算 c~⟨t⟩=tanh(wc[Γr⋅c⟨t−1⟩,x⟨t⟩]+bc)\tilde{c}^{\langle{t}\rangle} = tanh(w_c[\Gamma_r \cdot c^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_c)c~t=tanh(wc[Γrct1,xt]+bc)
  3. 计算 Γu=σ(wu[c⟨t−1⟩,x⟨t⟩]+bu)\Gamma_u = \sigma(w_u[c^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_u)Γu=σ(wu[ct1,xt]+bu)
  4. 最终 c⟨t⟩=Γu⋅c~⟨t⟩+(1−Γu)⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + (1-\Gamma_u) \cdot c^{\langle {t-1} \rangle}ct=Γuc~t+(1Γu)ct1

其中第一步计算 relevant gate,第二步利用 relevatn gate 计算 c~⟨t⟩\tilde{c}^{\langle t \rangle}c~t,其余步骤与之前相同。

关于为什么要使用这样的 RNN 单元,Andrew NG 对此有下面这一番话:

So why we use these architectures, why don’t we change them, how we know they will work, why not add another gate, why not use the simpler GRU instead of the full GRU; well researchers has experimented over years all the various types of these architectures with many many different versions and also addressing the vanishing gradient problem. They have found that full GRUs are one of the best RNN architectures to be used for many different problems. You can make your design but put in mind that GRUs and LSTMs are standards.

Long Short Term Memory

GRU 在解决梯度消失问题上的表现很不错,但在 GRU 提出之前,LSTM 存在已久,而 LSTM 比起 GRU 使用得更加普遍。

LSTM 与 GRU 十分相似。在 GRU 中,我们有 update gaterelevant gate,以及激活单元 c,而在 LSTM 中,没有_relevant gate_,但新增了 forget gateoutput gate,以及激活单元 c 和 a,下面我们来详细阐述。

在 GRU 中,我们使用 update gate 的来控制激活单元是否更新以及更新的程度,其目的是减少激活单元更新的次数或程度,好让之前的激活单元的值得到保留,换言之,记住句子前面部分的信息,这一点在 LSTM 中并没改变,只不过相比于 GRU 使用 1−Γu1 - \Gamma_u1Γu 来表示不更新的概率,LSTM 直接使用一个 forget gateΓf\Gamma_fΓf 来代替 1−Γu1 - \Gamma_u1Γu,下表可以清楚看出 GRU 与 LSTM 在计算 c⟨t⟩c^{\langle t \rangle}ct 时的区别。

GRULSTM
c⟨t⟩=Γu⋅c~⟨t⟩+(1−Γu)⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + (1-\Gamma_u) \cdot c^{\langle{t-1}\rangle}ct=Γuc~t+(1Γu)ct1c⟨t⟩=Γu⋅c~⟨t⟩+Γf⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + \Gamma_f \cdot c^{\langle{t-1}\rangle}ct=Γuc~t+Γfct1

另一个门,即 output gate 的作用是进一步控制激活单元更新的程度,在 GRU 中,上表算出的 c⟨t⟩c^{\langle t \rangle}ct 就是激活单元,而在 LSTM 中还需进一步计算,再用一张表表示。

GRU 的激活单元LSTM 的激活单元
c⟨t⟩=Γu⋅c~⟨t⟩+(1−Γu)⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + (1-\Gamma_u) \cdot c^{\langle{t-1}\rangle}ct=Γuc~t+(1Γu)ct1c⟨t⟩=Γu⋅c~⟨t⟩+Γf⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + \Gamma_f \cdot c^{\langle{t-1}\rangle}ct=Γuc~t+Γfct1
a⟨t⟩=Γo⋅tanh(c⟨t⟩)a^{\langle t \rangle} = \Gamma_o \cdot tanh(c^{\langle t \rangle})at=Γotanh(ct)

从表中可以看出,LSTM 的最终激活单元是 a,即 y⟨t⟩y^{\langle t \rangle}yt 是通过 a⟨t⟩a^{\langle t \rangle}at 的计算得出的,c⟨t⟩c^{\langle t \rangle}ct 只是中间变量,不过 c⟨t⟩c^{\langle t \rangle}cta⟨t⟩a^{\langle t \rangle}at 都会传向下一个单元,一会会用一张图表示这个过程。

介绍了 forget gateoutput gate 的作用后,让我们把 LSTM 的激活单元计算过程中涉及的计算式完整写一遍:

  1. c~⟨t⟩=tanh(wc[a⟨t−1⟩,x⟨t⟩]+bc)\tilde{c}^{\langle t \rangle} = tanh(w_c[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_c)c~t=tanh(wc[at1,xt]+bc)
  2. Γu=σ(wu[a⟨t−1⟩,x⟨t⟩]+bu)\Gamma_u = \sigma(w_u[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_u)Γu=σ(wu[at1,xt]+bu)
  3. Γf=σ(wf[a⟨t−1⟩,x⟨t⟩]+bf)\Gamma_f = \sigma(w_f[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_f)Γf=σ(wf[at1,xt]+bf)
  4. Γo=σ(wo[a⟨t−1⟩,x⟨t⟩]+bo)\Gamma_o = \sigma(w_o[a^{\langle t-1 \rangle}, x^{\langle t \rangle}] + b_o)Γo=σ(wo[at1,xt]+bo)
  5. c⟨t⟩=Γu⋅c~⟨t⟩+Γf⋅c⟨t−1⟩c^{\langle t \rangle} = \Gamma_u \cdot \tilde{c}^{\langle t \rangle} + \Gamma_f \cdot c^{\langle{t-1}\rangle}ct=Γuc~t+Γfct1
  6. a⟨t⟩=Γo⋅tanh(c⟨t⟩)a^{\langle t \rangle} = \Gamma_o \cdot tanh(c^{\langle t \rangle})at=Γotanh(ct)

计算顺序不一定按照上面的序号来。可以用一张图来表示 LSTM 的 RNN 单元的计算过程:

从图中可以看出,c⟨t⟩c^{\langle t \rangle}cta⟨t⟩a^{\langle t \rangle}at 都传向了下一个单元(这里说下一个单元有些不太准确,准确形容应该是 the next time step),但只有 a⟨t⟩a^{\langle t \rangle}at 参与了 y⟨t⟩y^{\langle t \rangle}yt 的计算。

下面这幅图展示了 LSTM 的前向传播过程。

### LSTM神经网络公式详细介绍 长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的循环神经网络(Recurrent Neural Network, RNN),其核心在于能够通过门控机制有效解决梯度消失和长期依赖问题。以下是LSTM的关键组成部分及其对应的数学公式。 #### 遗忘门 (Forget Gate) 遗忘门决定了细胞状态中哪些部分会被保留或丢弃。它基于当前输入 \(x_t\) 和前一时刻隐藏状态 \(h_{t-1}\),计算一个介于0到1之间的值,表示每个单元的状态被遗忘的程度: \[ f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \] 其中: - \(W_f\) 是权重矩阵, - \(b_f\) 是偏置项, - \(\sigma\) 表示Sigmoid激活函数[^1]。 #### 输入门 (Input Gate) 输入门控制新信息进入细胞状态的程度。这一过程分为两步:第一步是决定更新的内容;第二步是决定更新的比例。 \[ i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \] \[ \tilde{C}_t = tanh(W_C \cdot [h_{t-1}, x_t] + b_C) \] 这里: - \(i_t\) 控制更新比例; - \(\tilde{C}_t\) 是候选的新细胞状态[^2]。 #### 细胞状态更新 (Cell State Update) 新的细胞状态由旧的细胞状态经过遗忘操作后加上按比例缩放后的候选状态构成: \[ C_t = f_t * C_{t-1} + i_t * \tilde{C}_t \] 这一步综合了历史信息当前时间步的信息[^3]。 #### 输出门 (Output Gate) 最后,输出门决定了本时刻的输出如何生成。该输出基于更新后的细胞状态并受控于另一个Sigmoid层: \[ o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \] \[ h_t = o_t * tanh(C_t) \] 以上即为标准LSTM模型的主要公式描述。 ```python import numpy as np def sigmoid(x): return 1 / (1 + np.exp(-x)) def lstm_cell(input_data, prev_hidden_state, cell_state, weights): ft = sigmoid(np.dot(weights['wf'], input_data) + weights['bf']) it = sigmoid(np.dot(weights['wi'], input_data) + weights['bi']) Ct_hat = np.tanh(np.dot(weights['wc'], input_data) + weights['bc']) ot = sigmoid(np.dot(weights['wo'], input_data) + weights['bo']) new_cell_state = ft * cell_state + it * Ct_hat current_hidden_state = ot * np.tanh(new_cell_state) return current_hidden_state, new_cell_state ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值