从头构建LSTM

本文深入浅出地介绍了LSTM(长短期记忆网络)的工作原理,包括naive RNN的基础,LSTM的前向传播、损失函数和反向传播过程。通过详细的公式解析,阐述了LSTM如何在记忆和遗忘机制下处理序列数据,以及在训练过程中如何优化损失函数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


version:20210903

参考链接:

  1. 人人都能看懂的LSTM
  2. Simple LSTM-Nico’s blog
  3. github源码参考

自己阅读相关材料时整理的笔记,梳理知识点和思路用,仅供参考。(后面用空再补充代码部分的解析)

1. 概念理解

公式说明:

ϕ ( x ) = t a n h ( x ) ϕ(x)=tanh(x) ϕ(x)=tanh(x)
σ ( x ) = 1 1 + e − x σ(x)=\frac{1}{1+e^{−x}} σ(x)=1+ex1

t a n h tanh tanh函数取值范围为[-1,1];
s i g m o i d sigmoid sigmoid函数取值范围为(0,1)。

1.1. naive RNN

h ′ , y = f ( h , x ) h', y=f(h,x) h,y=f(h,x)
h ′ = σ ( w h + w i x ) h'=\sigma(w^h+w^ix) h=σ(wh+wix)

y = σ ( w o h ′ ) y=\sigma(w^oh') y=σ(woh)

其中:

x x x为当前节点状态下的输入, h h h表示接收到上一个节点的输入;

y y y为当前节点状态下的输出, h ′ h' h为传递到下一个节点的输出;

h ′ h' h x x x h h h的值都相关;

y y y则常常使用 h ′ h' h投入到一个线性层(主要是进行维度映射)然后使用softmax进行分类得到需要的数据。(往往看具体模型的使用方式)

1.2. LSTM

1.2.1. 前向forward

相比RNN只有一个传递状态 h t h^t ht,LSTM有两个传输状态,一个 s t s^t st(cell state),和一个 h t h^t ht(hidden state)。

s t , h t , y t = f ( s t − 1 , h t − 1 , x t ) s^t, h^t, y^t=f(s^{t-1}, h^{t-1}, x^t) st,ht,yt=f(st1,ht1,xt)

f t = σ ( W f x x t + W f h h t − 1 + b f ) f^t=σ(W_{fx}x^t+W_{fh}h^{t−1}+b_f) ft=σ(Wfxxt+Wfhht1+bf)

i t = σ ( W i x x t + W i h h t − 1 + b i ) i^t=σ(W_{ix}x^t+W_{ih}h^{t−1}+b_i) it=σ(Wixxt+Wihht1+bi)

o t = σ ( W o x x t + W o h h t − 1 + b o ) o^t=σ(W_{ox}x^t+W_{oh}h^{t−1}+b_o) ot=σ(Woxxt+Wohht1+bo)

g t = ϕ ( W g x x t + W g h h t − 1 + b g ) g^t=ϕ(W_{gx}x^t+W_{gh}h^{t−1}+b_g) gt=ϕ(Wgxxt+Wghht1+bg)

这部分在实际计算中,会将 x t x^t xt h t − 1 h^{t−1} ht1组合起来 x c t = [ x t , h t − 1 ] x_c^t=[x^t,h^{t−1}] xct=[xt,ht1],然后一起计算,故公式可以简写为如下形式:

f t = σ ( W f x c t + b f ) f^t=σ(W_{f}x_c^t+b_f) ft=σ(Wfxct+bf)

i t = σ ( W i x c t + b i ) i^t=σ(W_{i}x_c^t+b_i) it=σ(Wixct+bi)

o t = σ ( W o x c t + b o ) o^t=σ(W_{o}x_c^t+b_o) ot=σ(Woxct+bo)

g t = ϕ ( W g x c t + b g ) g^t=ϕ(W_{g}x_c^t+b_g) gt=ϕ(Wgxct+bg)

s t = g t ∗ i t + s t − 1 ∗ f t s^t=g^t∗i^t+s^{t−1}∗f^t st=gtit+st1ft

h t = o t ∗ s t h^t=o^t*s^t ht=otst

y ^ t = σ ( w ′ h t ) \hat{y}^t=\sigma(w'h^t) y^t=σ(wht)

Tips:

RNN中的 h t h^t ht对于LSTM中的 s t s^t st

s t s^t st改变得很慢,通常输出的 s t s^t st是上一个状态传过来的 s t − 1 s^{t-1} st1加上一些数值;

h t h^t ht则主要依赖当前节点的数据,所以在不同节点往往会有很大的区别。

其中:

  1. f f f表示forget,为忘记阶段。这个阶段主要是对上一个节点传进来的输入 s t − 1 s^{t−1} st1进行选择性忘记 f t f^t ft

  2. i i i代表input,为选择记忆阶段。这个阶段将这个阶段的输入 g t g^t gt(对原始的 x t x^t xt进行了tanh激活)有选择性地进行记忆 i t i^t it;

(将上面两步得到的结果相加,即可得到传输给下一个状态的 s t s^t st

  1. o o o代表output,主要控制输出阶段,这个阶段将决定哪些将会被当成当前状态的输出,主要是通过 o t o^t ot来进行控制。(这里未对上一阶段得到的 s t s^t st进行放缩,有需要还可以通过一个tanh激活函数进行变化: h t = o t ∗ ϕ ( s t ) h^t=o^t*ϕ(s^t) ht=otϕ(st)

  2. 输出 y ^ t \hat{y}^t y^t与普通RNN类似,往往最终也是通过 h t h^t ht变化得到。这里假设 y ^ t \hat{y}^t y^t= h t h^t ht,则下面的 y ^ t \hat{y}^t y^t直接写为 h t h^t ht

f f f i i i o o o都是门控(gate),使用 σ σ σ激活; g g g是作为输入数据的,不是门控状态,所以用 ϕ ϕ ϕ激活。

1.2.2. 损失函数lossFunc

定义每个时间步 t t t的损失函数为:

l ( t ) = f ( h ( t ) , y ( t ) ) l(t)=f(h(t),y(t)) l(t)=f(h(t),y(t))(1)

这里选用L2范数损失函数,也叫欧几里得损失函数,来计算loss,公式如下:

l ( t ) = f ( h ( t ) , y ( t ) ) = ∥ h ( t ) − y ( t ) ∥ 2 l(t)=f(h(t),y(t))=∥h(t)−y(t)∥^2 l(t)=f(h(t),y(t))=h(t)y(t)2

最终目标是通过梯度下降来使整个长度为 T T T的序列的损失 L L L最小化:

L = ∑ t = 1 T l ( t ) L=\sum_{t=1}^Tl(t) L=t=1Tl(t)

1.2.3. 反向传播backpropagation

下面来推导loss梯度:

d L d w \frac{dL}{dw} dwdL

∵ w \because w w是标量参数;且由(1)可知损失只与隐含层 h ( t ) h(t) h(t)和标签 y ( t ) y(t) y(t)有关;由链式法则

∴ d L d w = ∑ t = 1 T ∑ i = 1 M d L d h i ( t ) d h i ( t ) d w \therefore \frac{dL}{dw}=∑_{t=1}^T∑_{i=1}^M\frac{dL}{dh_i(t)}\frac{dh_i(t)}{dw} dwdL=t=1Ti=1Mdhi(t)dLdwdhi(t)(2)

其中 h i ( t ) h_i(t) hi(t)是一个标量,是第 i i i个memory cell的隐含层的输出, M M M是memory cell的总数。在网络中信息会随着时间向前传播,在时间 t t t,改变 h i ( t ) h_i(t) hi(t) t t t之前的损失没有什么影响,所以公式可以写成如下:

d L d h i ( t ) = ∑ s = 1 T d l ( s ) d h i ( t ) = ∑ s = t T d l ( s ) d h i ( t ) \frac{dL}{dh_i(t)}=∑_{s=1}^T\frac{dl(s)}{dh_i(t)}=∑_{s=t}^T\frac{dl(s)}{dh_i(t)} dhi(t)dL=s=1Tdhi(t)dl(s)=s=tTdhi(t)dl(s)(3)

为了方便,我们使用 L ( t ) L(t) L(t)来表示从 t t t开始的累计损失:

L ( t ) = ∑ s = t T l ( s ) L(t)=∑_{s=t}^{T}l(s) L(t)=s=tTl(s)(4)

所以,当 t = 1 t=1 t=1时, L ( 1 ) L(1) L(1)则表示整个序列的损失。故(3)可以写为:

d L d h i ( t ) = ∑ s = t T d l ( s ) d h i ( t ) = d L ( t ) d h i ( t ) \frac{dL}{dh_i(t)}=∑_{s=t}^T\frac{dl(s)}{dh_i(t)}=\frac{dL(t)}{dh_i(t)} dhi(t)dL=s=tTdhi(t)dl(s)=dhi(t)dL(t)

(2)可以写为:

d L d w = ∑ t = 1 T ∑ i = 1 M d L ( t ) d h i ( t ) d h i ( t ) d w \frac{dL}{dw}=∑_{t=1}^T∑_{i=1}^M\frac{dL(t)}{dh_i(t)}\frac{dh_i(t)}{dw} dwdL=t=1Ti=1Mdhi(t)dL(t)dwdhi(t)

d h i ( t ) d w \frac{dh_i(t)}{dw} dwdhi(t)部分就按照前向传播的公式去推导,下面介绍如何计算 d L ( t ) d h i ( t ) \frac{dL(t)}{dh_i(t)} dhi(t)dL(t)部分。

由(4)可得:

L ( t ) = { l ( t ) + L ( t + 1 ) , i f t < T l ( t ) , i f t = T L(t)=\begin{cases} l(t)+L(t+1), & if\quad t<T\\ l(t), & if\quad t=T \end{cases} L(t)={l(t)+L(t+1),l(t),ift<Tift=T

由此,给定时间 t t t和一个LSTM节点的 h ( t ) h(t) h(t),可得:

d L ( t ) d h ( t ) = d l ( t ) d h ( t ) + d L ( t + 1 ) d h ( t ) \frac{dL(t)}{dh(t)}=\frac{dl(t)}{dh(t)}+\frac{dL(t+1)}{dh(t)} dh(t)dL(t)=dh(t)dl(t)+dh(t)dL(t+1)

其中,前半部分 d l ( t ) d h ( t ) \frac{dl(t)}{dh(t)} dh(t)dl(t) h ( t ) h(t) h(t)在时间 t t t的损失 l ( t ) l(t) l(t)的求导;后半部分则体现了LSTM的recurrent性质。我们需要下一节点的derivative来计算当前节点的derivative。最终我们可以从 d L ( T ) d h ( T ) = d l ( T ) d h ( T ) \frac{dL(T)}{dh(T)}=\frac{dl(T)}{dh(T)} dh(T)dL(T)=dh(T)dl(T)开始,计算每个时间节点 t = 1 , … , T t=1,\dots,T t=1,,T,即 d L ( t ) d h ( t ) \frac{dL(t)}{dh(t)} dh(t)dL(t)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值