从头构建LSTM
version:20210903
参考链接:
自己阅读相关材料时整理的笔记,梳理知识点和思路用,仅供参考。(后面用空再补充代码部分的解析)
1. 概念理解
公式说明:
ϕ
(
x
)
=
t
a
n
h
(
x
)
ϕ(x)=tanh(x)
ϕ(x)=tanh(x)
σ
(
x
)
=
1
1
+
e
−
x
σ(x)=\frac{1}{1+e^{−x}}
σ(x)=1+e−x1
t a n h tanh tanh函数取值范围为[-1,1];
s i g m o i d sigmoid sigmoid函数取值范围为(0,1)。
1.1. naive RNN
h
′
,
y
=
f
(
h
,
x
)
h', y=f(h,x)
h′,y=f(h,x)
h
′
=
σ
(
w
h
+
w
i
x
)
h'=\sigma(w^h+w^ix)
h′=σ(wh+wix)
y = σ ( w o h ′ ) y=\sigma(w^oh') y=σ(woh′)
其中:
x x x为当前节点状态下的输入, h h h表示接收到上一个节点的输入;
y y y为当前节点状态下的输出, h ′ h' h′为传递到下一个节点的输出;
h ′ h' h′与 x x x和 h h h的值都相关;
y y y则常常使用 h ′ h' h′投入到一个线性层(主要是进行维度映射)然后使用softmax进行分类得到需要的数据。(往往看具体模型的使用方式)
1.2. LSTM
1.2.1. 前向forward
相比RNN只有一个传递状态 h t h^t ht,LSTM有两个传输状态,一个 s t s^t st(cell state),和一个 h t h^t ht(hidden state)。
s t , h t , y t = f ( s t − 1 , h t − 1 , x t ) s^t, h^t, y^t=f(s^{t-1}, h^{t-1}, x^t) st,ht,yt=f(st−1,ht−1,xt)
f t = σ ( W f x x t + W f h h t − 1 + b f ) f^t=σ(W_{fx}x^t+W_{fh}h^{t−1}+b_f) ft=σ(Wfxxt+Wfhht−1+bf)
i t = σ ( W i x x t + W i h h t − 1 + b i ) i^t=σ(W_{ix}x^t+W_{ih}h^{t−1}+b_i) it=σ(Wixxt+Wihht−1+bi)
o t = σ ( W o x x t + W o h h t − 1 + b o ) o^t=σ(W_{ox}x^t+W_{oh}h^{t−1}+b_o) ot=σ(Woxxt+Wohht−1+bo)
g t = ϕ ( W g x x t + W g h h t − 1 + b g ) g^t=ϕ(W_{gx}x^t+W_{gh}h^{t−1}+b_g) gt=ϕ(Wgxxt+Wghht−1+bg)
这部分在实际计算中,会将 x t x^t xt和 h t − 1 h^{t−1} ht−1组合起来 x c t = [ x t , h t − 1 ] x_c^t=[x^t,h^{t−1}] xct=[xt,ht−1],然后一起计算,故公式可以简写为如下形式:
f t = σ ( W f x c t + b f ) f^t=σ(W_{f}x_c^t+b_f) ft=σ(Wfxct+bf)
i t = σ ( W i x c t + b i ) i^t=σ(W_{i}x_c^t+b_i) it=σ(Wixct+bi)
o t = σ ( W o x c t + b o ) o^t=σ(W_{o}x_c^t+b_o) ot=σ(Woxct+bo)
g t = ϕ ( W g x c t + b g ) g^t=ϕ(W_{g}x_c^t+b_g) gt=ϕ(Wgxct+bg)
s t = g t ∗ i t + s t − 1 ∗ f t s^t=g^t∗i^t+s^{t−1}∗f^t st=gt∗it+st−1∗ft
h t = o t ∗ s t h^t=o^t*s^t ht=ot∗st
y ^ t = σ ( w ′ h t ) \hat{y}^t=\sigma(w'h^t) y^t=σ(w′ht)
Tips:
RNN中的 h t h^t ht对于LSTM中的 s t s^t st
s t s^t st改变得很慢,通常输出的 s t s^t st是上一个状态传过来的 s t − 1 s^{t-1} st−1加上一些数值;
h t h^t ht则主要依赖当前节点的数据,所以在不同节点往往会有很大的区别。
其中:
-
f f f表示forget,为忘记阶段。这个阶段主要是对上一个节点传进来的输入 s t − 1 s^{t−1} st−1进行选择性忘记 f t f^t ft;
-
i i i代表input,为选择记忆阶段。这个阶段将这个阶段的输入 g t g^t gt(对原始的 x t x^t xt进行了tanh激活)有选择性地进行记忆 i t i^t it;
(将上面两步得到的结果相加,即可得到传输给下一个状态的 s t s^t st)
-
o o o代表output,主要控制输出阶段,这个阶段将决定哪些将会被当成当前状态的输出,主要是通过 o t o^t ot来进行控制。(这里未对上一阶段得到的 s t s^t st进行放缩,有需要还可以通过一个tanh激活函数进行变化: h t = o t ∗ ϕ ( s t ) h^t=o^t*ϕ(s^t) ht=ot∗ϕ(st)
-
输出 y ^ t \hat{y}^t y^t与普通RNN类似,往往最终也是通过 h t h^t ht变化得到。这里假设 y ^ t \hat{y}^t y^t= h t h^t ht,则下面的 y ^ t \hat{y}^t y^t直接写为 h t h^t ht。
f f f, i i i, o o o都是门控(gate),使用 σ σ σ激活; g g g是作为输入数据的,不是门控状态,所以用 ϕ ϕ ϕ激活。
1.2.2. 损失函数lossFunc
定义每个时间步 t t t的损失函数为:
l ( t ) = f ( h ( t ) , y ( t ) ) l(t)=f(h(t),y(t)) l(t)=f(h(t),y(t))(1)
这里选用L2范数损失函数,也叫欧几里得损失函数,来计算loss,公式如下:
l ( t ) = f ( h ( t ) , y ( t ) ) = ∥ h ( t ) − y ( t ) ∥ 2 l(t)=f(h(t),y(t))=∥h(t)−y(t)∥^2 l(t)=f(h(t),y(t))=∥h(t)−y(t)∥2
最终目标是通过梯度下降来使整个长度为 T T T的序列的损失 L L L最小化:
L = ∑ t = 1 T l ( t ) L=\sum_{t=1}^Tl(t) L=t=1∑Tl(t)
1.2.3. 反向传播backpropagation
下面来推导loss梯度:
d L d w \frac{dL}{dw} dwdL
∵ w \because w ∵w是标量参数;且由(1)可知损失只与隐含层 h ( t ) h(t) h(t)和标签 y ( t ) y(t) y(t)有关;由链式法则
∴ d L d w = ∑ t = 1 T ∑ i = 1 M d L d h i ( t ) d h i ( t ) d w \therefore \frac{dL}{dw}=∑_{t=1}^T∑_{i=1}^M\frac{dL}{dh_i(t)}\frac{dh_i(t)}{dw} ∴dwdL=t=1∑Ti=1∑Mdhi(t)dLdwdhi(t)(2)
其中 h i ( t ) h_i(t) hi(t)是一个标量,是第 i i i个memory cell的隐含层的输出, M M M是memory cell的总数。在网络中信息会随着时间向前传播,在时间 t t t,改变 h i ( t ) h_i(t) hi(t)对 t t t之前的损失没有什么影响,所以公式可以写成如下:
d L d h i ( t ) = ∑ s = 1 T d l ( s ) d h i ( t ) = ∑ s = t T d l ( s ) d h i ( t ) \frac{dL}{dh_i(t)}=∑_{s=1}^T\frac{dl(s)}{dh_i(t)}=∑_{s=t}^T\frac{dl(s)}{dh_i(t)} dhi(t)dL=s=1∑Tdhi(t)dl(s)=s=t∑Tdhi(t)dl(s)(3)
为了方便,我们使用 L ( t ) L(t) L(t)来表示从 t t t开始的累计损失:
L ( t ) = ∑ s = t T l ( s ) L(t)=∑_{s=t}^{T}l(s) L(t)=s=t∑Tl(s)(4)
所以,当 t = 1 t=1 t=1时, L ( 1 ) L(1) L(1)则表示整个序列的损失。故(3)可以写为:
d
L
d
h
i
(
t
)
=
∑
s
=
t
T
d
l
(
s
)
d
h
i
(
t
)
=
d
L
(
t
)
d
h
i
(
t
)
\frac{dL}{dh_i(t)}=∑_{s=t}^T\frac{dl(s)}{dh_i(t)}=\frac{dL(t)}{dh_i(t)}
dhi(t)dL=s=t∑Tdhi(t)dl(s)=dhi(t)dL(t)
(2)可以写为:
d L d w = ∑ t = 1 T ∑ i = 1 M d L ( t ) d h i ( t ) d h i ( t ) d w \frac{dL}{dw}=∑_{t=1}^T∑_{i=1}^M\frac{dL(t)}{dh_i(t)}\frac{dh_i(t)}{dw} dwdL=t=1∑Ti=1∑Mdhi(t)dL(t)dwdhi(t)
d h i ( t ) d w \frac{dh_i(t)}{dw} dwdhi(t)部分就按照前向传播的公式去推导,下面介绍如何计算 d L ( t ) d h i ( t ) \frac{dL(t)}{dh_i(t)} dhi(t)dL(t)部分。
由(4)可得:
L ( t ) = { l ( t ) + L ( t + 1 ) , i f t < T l ( t ) , i f t = T L(t)=\begin{cases} l(t)+L(t+1), & if\quad t<T\\ l(t), & if\quad t=T \end{cases} L(t)={l(t)+L(t+1),l(t),ift<Tift=T
由此,给定时间 t t t和一个LSTM节点的 h ( t ) h(t) h(t),可得:
d L ( t ) d h ( t ) = d l ( t ) d h ( t ) + d L ( t + 1 ) d h ( t ) \frac{dL(t)}{dh(t)}=\frac{dl(t)}{dh(t)}+\frac{dL(t+1)}{dh(t)} dh(t)dL(t)=dh(t)dl(t)+dh(t)dL(t+1)
其中,前半部分 d l ( t ) d h ( t ) \frac{dl(t)}{dh(t)} dh(t)dl(t)为 h ( t ) h(t) h(t)在时间 t t t的损失 l ( t ) l(t) l(t)的求导;后半部分则体现了LSTM的recurrent性质。我们需要下一节点的derivative来计算当前节点的derivative。最终我们可以从 d L ( T ) d h ( T ) = d l ( T ) d h ( T ) \frac{dL(T)}{dh(T)}=\frac{dl(T)}{dh(T)} dh(T)dL(T)=dh(T)dl(T)开始,计算每个时间节点 t = 1 , … , T t=1,\dots,T t=1,…,T,即 d L ( t ) d h ( t ) \frac{dL(t)}{dh(t)} dh(t)dL(t)。