注:此篇博客主要有参考中的资料摘选集合而成,原创内容很少,如有困惑的地方可以详细阅读参考中的资料,写的真的很好,强烈推荐阅读。
一、实现长期依赖
理论上,各个门的值应该在 [0, 1] 之间(sigmoid)。但是如果真正训过一些表现良好的网络并且查看过门的值,就会发现很多时候门的值都是非常接近 0 或者 1 的,而类似于 0.2、0.5等 中间值很少。从直觉上我们希望门是二元的(0或1),控制信息流动的通和断,事实上训练良好的门也确实能达到这种效果。
1.选择性
想让信息流动的时候的就让它流动,不想让它流动的时候就关掉。例如做情感分析时,只让有情感极性的词和关联词等信息输入进网络,把别的忽略掉。这样一来,网络需要记忆的内容更少,自然也更容易记住。同样以 LSTM 为例,如果某个时刻 forget gate 是 0,虽然把网络的记忆清空了、回传的梯度也截断掉了,但这是 feature,不是 bug。这里举一个需要选择性的任务:给定一个序列,前面的字符都是英文字符,最后以三个下划线结束(例如 “abcdefg___”);要求模型每次读入一个字符,在读入英文字符时输出下划线,遇到下划线后输出它遇到的前三个字符(对上面的例子,输出应该是 “_______abc”)。显然,为了完成这个任务,模型需要学会记数(数到 3),只读入前三个英文字符,中间的字符都忽略掉,最终遇到 _ 时再输出它所记住的三个字符。“只读前三个字符”体现的就是选择性。
2.信息不变形
模型状态在跨时间步时不存在非线性变换,而是加性的。假如普通 RNN 的状态里存了某个信息,经过多个时间步以后多次非线性变换把信息变得面目全非了,即使这个信息模型仍然记得,但是也读取不出来它当时到底存的是什么了。而引入门机制的 RNN 单元在各个时间步是乘上一些 0/1 掩码再加新信息,没有非线性变换,这就导致网络想记住的内容(对应掩码为 1)过多个时间步记得还是很清楚。
二、缓解梯度消失
1.RNN为什么会梯度爆炸和消失
根据RNN的反向传播算法BPTT(Back Propagation Through Time)有
梯度消失主要就是针对上面(2)和(3)两个式子来讲的,其中的连乘项中当
t
t
t大于
j
j
j很多时(长距连接),
∏
j
=
k
+
1
t
∂
h
j
∂
h
j
−
1
\prod_{j=k+1}^{t} \frac{\partial h_{j}}{\partial h_{j-1}}
∏j=k+1t∂hj−1∂hj很容易近似于零或者非常大
h
t
=
σ
(
W
i
x
t
+
W
h
h
t
−
1
)
h_{t}= \sigma(W^{i}x_{t}+W^{h}h_{t-1})
ht=σ(Wixt+Whht−1)
∂
h
j
∂
h
j
−
1
=
σ
′
W
h
\frac{\partial h_{j}}{\partial h_{j-1}}= \sigma ^{'}W^{h}
∂hj−1∂hj=σ′Wh针对梯度消失来说有一个问题,即使“长距”连接部分的值很小,但这只是求和内容中的子项,公式中还存在大量的"短距"连乘项,这些项不是仍然可以组成梯度方向用于参数更新么?
可能的答案是在随机梯度下降的最开始时候,公式中的"短距"连乘项或许会产生一些梯度方向,但是随着随着参数的动态更新,这些"短距"连乘项构成的方向会引导Loss Function到局部最优的位置上去,而局部最优的地方就是梯度为0的地方,因此这些短距项也就趋向于零了。此时要达到更好的效果就要根据长距部分的梯度进行优化,而长距部分要是也趋近于零就意味着会发生梯度消失。并且长距梯度对优化不起作用也从说明了RNN没有长期依赖,本质上梯度消失还是没办法建立起长期依赖的问题
2.解决办法
对于记忆状态的路径来说
c
t
=
f
t
c
t
−
1
+
i
t
g
t
c_{t}=f_{t}c_{t-1}+i_{t}g{t}
ct=ftct−1+itgt,求导为
∂
c
t
∂
c
t
−
1
=
f
t
+
∂
f
t
∂
c
t
−
1
+
.
.
.
\frac{\partial c_{t}}{\partial c_{t-1}}=f_{t}+\frac{\partial f_{t}}{\partial c_{t-1}}+...
∂ct−1∂ct=ft+∂ct−1∂ft+...
不关心其他项,就第一项而言
f
t
f_{t}
ft为遗忘门的输出,结果在[0,1],但如前文所说更多的情况要么是零要么是1。在零的情况下,说明要忘记所有前面的信息,因此梯度不必回传。当为1时,自然也就不会发生梯度消失。
但是在其他路径上,LSTM 的梯度流和普通 RNN 没有太大区别,依然会爆炸或者消失。由于总的远距离梯度 = 各条路径的远距离梯度之和,即便其他远距离路径梯度消失了,只要保证有一条远距离路径(就是上面说的记忆状态路径)梯度不消失,总的远距离梯度就不会消失(正常梯度 + 消失梯度 = 正常梯度)。因此 LSTM 通过改善一条路径上的梯度问题拯救了总体的远距离梯度。
同样,因为总的远距离梯度 = 各条路径的远距离梯度之和,高速公路上梯度流比较稳定,但其他路径上梯度有可能爆炸,此时总的远距离梯度 = 正常梯度 + 爆炸梯度 = 爆炸梯度,因此 LSTM 仍然有可能发生梯度爆炸。不过,由于 LSTM 的其他路径非常崎岖,和普通 RNN 相比多经过了很多次激活函数(导数都小于 1),因此 LSTM 发生梯度爆炸的频率要低得多。实践中梯度爆炸一般通过梯度裁剪来解决。
总结:
对于现在常用的带遗忘门的 LSTM 来说,其一是遗忘门接近 1(可以在模型初始化时把 forget bias 设置成较大的正数,让遗忘门饱和),这时候远距离梯度不消失;其二是遗忘门接近 0,但这时模型是故意阻断梯度流的,这不是 bug 而是 feature。当然,也存在 f 介于 [0, 1] 之间的情况,在这种情况下只能说 LSTM 缓解而非解决了梯度消失的状况。
参考
RNN 中学习长期依赖的三种机制https://zhuanlan.zhihu.com/p/34490114
漫谈LSTM系列的梯度问题https://zhuanlan.zhihu.com/p/36101196