习题6-4 推导LSTM网络中参数的梯度, 并分析其避免梯度消失的效果
首先需要明确的是,RNN 中的梯度消失/梯度爆炸和普通的 MLP 或者深层 CNN 中梯度消失/梯度爆炸的含义不一样。MLP/CNN 中不同的层有不同的参数,各是各的梯度;而 RNN 中同样的权重在各个时间步共享,最终的梯度为各个时间步的梯之和。
因此,RNN 中总的梯度是不会消失的。即便梯度越传越弱,那也只是远距离的梯度消失,由于近距离的梯度不会消失,所有梯度之和便不会消失。RNN 所谓梯度消失的真正含义是,梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系。
有多条求导路径,最后将这些求导路径相加得到最终的梯度,只要保证有一条远距离路径梯度不消失,总的远距离梯度就不会消失,这里我们只关注传到
的过程
下图箭头为为从传到
的过程的相反方向,即反向传播的方向,值得注意的是,从
传到
,包括从
传到
,再从
传到
。

可以求得公式

可以看到,这个梯度是由四个式子相加而成,其他姑且不用看,就看这个,控制好
就能让它梯度不消失。
也就是说,如果想让过去的影响现在的输出,让为1就可以了,之后在上图最上面那条路径中梯度会直接从
流到
,梯度不会消失,
的大小由训练学习得到的(它会学习到什么时候就不记前头的了),直到有输入x得到的
为0,则和前面的信息就没有关系了。
就是反向传到较远的地方,它的梯度是有多个式子的和(多条路到达),而不是仅仅是几个变量相乘的单一式子。
注意:梯度消失现象可以改善,但是梯度爆炸还是可能会出现的。
LSTM依然不能完全解决梯度消失这个问题,有文献表示序列长度一般到了三百多仍然会出现梯度消失现象。如果想彻底规避这个问题,还是transformer好用。(现在还没学 transformer,浅期待一下)
参考老师给的鱼书上的看法分析,避免梯度消失主要有两个原因,一个是使用对应元素的乘积,一个是遗忘门的使用。

习题6-3P 编程实现下图LSTM运行过程
同学提出,未发现输入。可以适当改动例题,增加该输入。
实现LSTM算子,可参考实验教材代码。

文章探讨了LSTM网络中参数梯度的计算,解释了RNN中的梯度消失问题并非真正的消失,而是由权重共享导致的梯度减弱。重点介绍了遗忘门的作用以及如何通过控制权重来避免远距离梯度消失。同时,文章提到了LSTM无法完全解决梯度消失,Transformer在此方面表现更好。
最低0.47元/天 解锁文章
1277





