问题引入
其实这算是个经典的问题了,在一般的只要你做过时间序列相关的项目或者理论的时候,LSTM和RNN的对比肯定是要问的。那两者到底有啥区别呢?
问题回答
其实对于这个问题,要从RNN在发展过程中带来的令人诟病的短处说起,RNN在train参数的时候,需要反向传播梯度,这个梯度是这么算的:
w
i
+
1
=
w
i
−
r
⋅
∂
L
o
s
s
∂
w
∣
w
:
w
i
,
r
>
0
w^{i+1}=w^{i}-r\cdot\frac{\partial{Loss }}{\partial{w}}|_{w:w^{i}},r>0
wi+1=wi−r⋅∂w∂Loss∣w:wi,r>0
其中
r
r
r是学习率,
∂
L
o
s
s
∂
w
∣
w
:
w
i
\frac{\partial{Loss }}{\partial{w}}|_{w:w^{i}}
∂w∂Loss∣w:wi是损失函数在w处的导数,针对RNN在结构上很深的特征,会产生梯度消失和梯度爆炸,其中需要了解下什么是梯度消失和梯度爆炸,梯度消失指的是,RNN在某些
w
i
w^i
wi取值上,导致梯度很小,梯度爆炸指的是,
w
i
w^i
wi在某些取值上,导致梯度特别大。如果你的学习率
r
r
r不变的话,那么参数要么几乎不变,要么就是变化剧烈,到时迭代动荡很难手收敛。通过我们对RNN的网络结构的建模,我们发现他的梯度是这个样子的:
∂
L
t
∂
W
h
=
∑
t
=
0
T
∑
k
=
0
t
∂
L
t
∂
y
t
∂
y
t
∂
h
t
(
∏
j
=
k
+
1
t
∂
h
j
∂
h
j
−
1
)
∂
h
k
∂
W
h
\frac{\partial{L_{t}}}{\partial{W^{h}}}=\sum_{t=0}^{T}{\sum_{k=0}^{t}{ \frac{\partial{L_{t}}}{\partial{y_t}} \frac{\partial{y_{t}}}{\partial{h_t}} (\prod_{j=k+1}^{t} \frac{\partial{h_{j}}}{\partial{h_{j-1}}} ) \frac{\partial{h_{k}}}{\partial{W^h}}}}
∂Wh∂Lt=t=0∑Tk=0∑t∂yt∂Lt∂ht∂yt(j=k+1∏t∂hj−1∂hj)∂Wh∂hk
我们先不管这一大串公式是啥意思,大值得意思就是上面公式里面有依赖于时间
t
t
t的连乘符号;修正
t
t
t时刻的误差需要考虑之前的所有时间
k
k
k的隐藏层对时间
t
t
t的影响,当
k
k
k和
t
t
t距离越远,对应着隐含层之间的连乘次数就越多。就是这个连乘的结构产生了梯度消失,梯度爆炸也是它导致的。具体大一大波公式有需要看的话可以看下参考中的地(我只是搬运工)。
而LSTM(长短时记忆网络),因为可以通过阀门(gate,其实就是概率,共有输出、遗忘、输入三个阀门)记忆一些长期信息,所以,相比RNN,保留了更多长期信息(相应地也就保留了更多的梯度)。隐层之间的输入输出可以表示为:
c
j
=
σ
(
W
f
X
j
+
b
f
)
c
j
−
1
+
σ
(
W
i
X
j
+
b
i
)
σ
(
W
X
j
+
b
)
c_{j}=\sigma(W^fX_{j}+b^f)c_{j-1}+\sigma({W^iX_{j}}+b^i)\sigma(WX_{j}+b)
cj=σ(WfXj+bf)cj−1+σ(WiXj+bi)σ(WXj+b),于是连乘的项可以表示为:
∂
c
j
∂
c
j
−
1
=
σ
(
W
f
X
j
+
b
)
\frac{\partial{c_{j}}}{\partial{c_{j-1}}}=\sigma(W^fX_{j}+b)
∂cj−1∂cj=σ(WfXj+b)
该值得范围在0-1之间,在参数更新的过程中,可以通过控制bais较大来控制梯度保持在1,即使通过多次的连乘操作,梯度也不会下降到消失的状态。所以,相比RNN,在LSTM上,梯度消失问题得到了一定程度的缓解。
更多内容,查看如下 机器学习算法面试:
https://www.zhihu.com/question/44895610/answer/616818627
https://zhuanlan.zhihu.com/p/30844905
https://blog.youkuaiyun.com/laolu1573/article/details/77470889