前言
梯度爆炸和梯度消失问题都是因为网络太深,网络权值更新不稳定造成的,本质上是因为梯度反向传播中的连乘效应。
前向传播:
z1=w1X+b1,a1=σ(z1)z2=w2a1+b2,a2=σ(z2)...zn=wnan−1+bn,an=σ(zn)
\begin{aligned}
z_1&=w_1X+b_1,a_1=\sigma (z_1)\\
z_2&=w_2a_1+b_2,a_2=\sigma(z_2)\\
...\\
z_n&=w_na_{n-1+b_n},a_n=\sigma(z_n)\\
\end{aligned}
z1z2...zn=w1X+b1,a1=σ(z1)=w2a1+b2,a2=σ(z2)=wnan−1+bn,an=σ(zn)
则反向传播:
αlossαw1=αlossαanαanαznαznαan−1αan−1αzn−1αzn−1αan−2αan−2αzn−2...αa1αz1αz1αw1=αlosssαan⋅σ′(zn)wn⋅σ′(zn−1)wn−1⋅...⋅σ′(z1)X
\begin{aligned}
\frac{\alpha loss}{\alpha w_1}
&=\frac{\alpha loss}{\alpha a_n}\frac{\alpha a_n}{\alpha z_n}\frac{\alpha z_n}{\alpha a_{n-1}}\frac{\alpha a_{n-1}}{\alpha z_{n-1}}\frac{\alpha z_{n-1}}{\alpha a_{n-2}}\frac{\alpha a_{n-2}}{\alpha z_{n-2}}...\frac{\alpha a_1}{\alpha z_1}\frac{\alpha z_1}{\alpha w_1}\\
&=\frac{\alpha losss}{\alpha a_n}·\sigma'(z_n)w_n·\sigma'(z_{n-1})w_{n-1}·...·\sigma'(z_1)X
\end{aligned}
αw1αloss=αanαlossαznαanαan−1αznαzn−1αan−1αan−2αzn−1αzn−2αan−2...αz1αa1αw1αz1=αanαlosss⋅σ′(zn)wn⋅σ′(zn−1)wn−1⋅...⋅σ′(z1)X
-
梯度消失:与激活函数的导数σ′(x)\sigma^{'}(x)σ′(x)有关。
假如σ\sigmaσ为sigmoid激活函数,而sigmoid的导数范围是[0,0.25],"链式法则"的累乘会导致梯度趋于0. -
梯度爆炸:与权重有关,即∣σ′(z)w∣>1|\sigma'(z) w|>1∣σ′(z)w∣>1。
链式法则还与∣σ′(z)w∣|\sigma'(z) w|∣σ′(z)w∣有关,如果该值>1,"链式法则"累乘后会导致梯度趋于非常大的值.
梯度消失
与梯度太小有关。表现为只在后层学习,浅层不学习,浅层梯度基本无,权重改变量小,收敛慢,训练速度慢。
原因:
- 采用了不适合的激活函数,导致链式法则累乘时被0影响。
- 模型在训练的过程中,会不断调整数据分布,有可能接近激活函数饱和区,此时的导数很小,难以调整权重。
解决办法:
- 使用BN,将数据分布归一化。
- 预训练,微调。
- 使用relu等激活函数。
- 使用残差结构。
- LSTM。
- 正则化。
梯度爆炸
与链式法则中的权重有关。可能导致权重NAN。
原因:
- 若初始化权重太大,累乘后会爆炸。
- 梯度>1。
解决办法:
- 注意权重初始化。
- 梯度剪裁。
- BN。
- 预训练,微调。
RNN为何会梯度消失/爆炸?
首先看RNN计算流程,简设3个timestep:
前向传播:
S1=WxX1+WsS0+b1S_1=W_xX_1+W_sS_0+b1S1=WxX1+WsS0+b1, O1=WoS1+b2O_1=W_oS_1+b2O1=WoS1+b2。
S2=WxX2+WsS1+b1S_2=W_xX_2+W_sS_1+b1S2=WxX2+WsS1+b1, O2=WoS2+b2O_2=W_oS_2+b2O2=WoS2+b2。
S3=WxX3+WsS2+b1S_3=W_xX_3+W_sS_2+b1S3=WxX3+WsS2+b1, O3=WoS3+b2O_3=W_oS_3+b2O3=WoS3+b2。
此刻的损失函数:loss3=12(Y3−O3)2loss_3=\frac{1}{2}(Y_3-O_3)^2loss3=21(Y3−O3)2。
反向传播:
需要对WoW_oWo, WsW_sWs, WxW_xWx求导,其中对WsW_sWs和WxW_xWx求导是同理的。
(1) δloss3δWo=δloss3δO3δO3δWo\frac{\delta loss_3}{\delta W_o}=\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta W_o}δWoδloss3=δO3δloss3δWoδO3
可以看出网络加深对于WoW_oWo无影响。
(2) δloss3δWs=δloss3δO3δO3δS3δS3δWs+δloss3δO3δO3δS3δS3δS2δS2δWs+δloss3δO3δO3δS3δS3δS2δS2δS1δS1δWs\frac{\delta loss_3}{\delta W_s}=\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta S_3}\frac{\delta S_3}{\delta W_s}+\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta S_3}\frac{\delta S_3}{\delta S_2}\frac{\delta S_2}{\delta W_s}+\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta S_3}\frac{\delta S_3}{\delta S_2}\frac{\delta S_2}{\delta S_1}\frac{\delta S_1}{\delta W_s}δWsδloss3=δO3δloss3δS3δO3δWsδS3+δO3δloss3δS3δO3δS2δS3δWsδS2+δO3δloss3δS3δO3δS2δS3δS1δS2δWsδS1。
可以简写为:
δlosstδWs=∑k=0tδlosstδOtδOtδSt∏j=k+1t(δSjδSj−1)δSkδWx\frac{\delta loss_t}{\delta W_s}=\sum_{k=0}^t\frac{\delta loss_t}{\delta O_t}\frac{\delta O_t}{\delta S_t}\prod_{j=k+1}^t(\frac{\delta S_j}{\delta S_{j-1}})\frac{\delta S_k}{\delta W_x}δWsδlosst=∑k=0tδOtδlosstδStδOt∏j=k+1t(δSj−1δSj)δWxδSk。
其中连乘的∏j=k+1t(δSjδSj−1)\prod_{j=k+1}^t(\frac{\delta S_j}{\delta S_{j-1}})∏j=k+1t(δSj−1δSj)是导致梯度爆炸和消失的问题所在。
RNN梯度与其他网络梯度的区别
- MLP/CNN 中不同的层有不同的参数,各是各的梯度;而 RNN 中同样的权重在各个时间步共享,最终的梯度= 各个时间步的梯度的和。
- RNN 中总的梯度是不会消失的。即便梯度越传越弱,那也只是远距离的梯度消失,由于近距离的梯度不会消失,所有梯度之和便不会消失。RNN 所谓梯度消失的真正含义是,梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系。
LSTM如何缓解梯度消失/爆炸?
LSTM介绍
- 遗忘门
- 可求得:ft=σ(Wf⋅[ht−1,xt]+bf)f_t=\sigma (W_f·[h_{t-1},x_t]+b_f)ft=σ(Wf⋅[ht−1,xt]+bf).
- 可求得:ft=σ(Wf⋅[ht−1,xt]+bf)f_t=\sigma (W_f·[h_{t-1},x_t]+b_f)ft=σ(Wf⋅[ht−1,xt]+bf).
- 输入门
可求得:- it=σ(Wi⋅[ht−1,xt]+bi)i_t=\sigma (W_i·[h_{t-1},x_t]+b_i)it=σ(Wi⋅[ht−1,xt]+bi).
- C^t=tanh(WC⋅[ht−1,xt]+bC)\hat C_t=tanh (W_C·[h_{t-1},x_t]+b_C)C^t=tanh(WC⋅[ht−1,xt]+bC).
- Ct=ft⋅Ct−1+it⋅C^tC_t=f_t·C_{t-1}+i_t·\hat C_tCt=ft⋅Ct−1+it⋅C^t
- 输出门
可求得:- Ot=σ(Wo⋅[ht−1,xt]+bo)O_t=\sigma (W_o·[h_{t-1},x_t]+b_o)Ot=σ(Wo⋅[ht−1,xt]+bo).
- ht=Ot⋅tanh(Ct)h_t=O_t·tanh(C_t)ht=Ot⋅tanh(Ct).
- Ot=σ(Wo⋅[ht−1,xt]+bo)O_t=\sigma (W_o·[h_{t-1},x_t]+b_o)Ot=σ(Wo⋅[ht−1,xt]+bo).
LSTM可以解决梯度消失,缓解梯度爆炸
整理可得公式:
- ft=σ(Wf⋅[ht−1,xt]+bf)f_t=\sigma (W_f·[h_{t-1},x_t]+b_f)ft=σ(Wf⋅[ht−1,xt]+bf).
- it=σ(Wi⋅[ht−1,xt]+bi)i_t=\sigma (W_i·[h_{t-1},x_t]+b_i)it=σ(Wi⋅[ht−1,xt]+bi).
- C^t=tanh(Wc⋅[ht−1,xt]+bc)\hat C_t=tanh (W_c·[h_{t-1},x_t]+b_c)C^t=tanh(Wc⋅[ht−1,xt]+bc).
- Ct=ft⋅Ct−1+it⋅C^tC_t=f_t·C_{t-1}+i_t·\hat C_tCt=ft⋅Ct−1+it⋅C^t
- Ot=σ(Wo⋅[ht−1,xt]+bo)O_t=\sigma (W_o·[h_{t-1},x_t]+b_o)Ot=σ(Wo⋅[ht−1,xt]+bo).
- ht=Ot⋅tanh(Ct)h_t=O_t·tanh(C_t)ht=Ot⋅tanh(Ct).
-
LSTM 中梯度的传播有很多条路径,Ct−1→Ct=ft⋅ct−1+it⋅c^tC_{t-1} \rightarrow C_t=f_t·c_{t-1}+i_t·\hat c_tCt−1→Ct=ft⋅ct−1+it⋅c^t这条路径上只有逐元素相乘和相加的操作,梯度流最稳定;但是其他路径(例如Ct−1→ht−1→it→ctC_{t-1} \rightarrow h_{t-1} \rightarrow i_t \rightarrow c_tCt−1→ht−1→it→ct)上梯度流与普通 RNN 类似,照样会发生相同的权重矩阵反复连乘。根据上式可以看出CtC_tCt公式与hth_tht, iti_tit, C^t\hat C_tC^t, Ct−1C_{t-1}Ct−1有关,则可以得出:
δC(k)δC(k−1)=δC(k)δf(k)δf(k)δh(k−1)δh(k−1)δC(k−1)[ht公式]+δC(k)δi(k)δi(k)δh(k−1)δh(k−1)δC(k−1)[it公式]+δC(k)δC^(k)δC^(k)δh(k−1)δh(k−1)δC(k−1)[C^t公式]+δC(k)δC(k−1)[Ct公式]=Ct−1(σ′⋅Wf)(ot⋅tanh′)+C^t(σ′⋅Wi)(ot⋅tanh′)+it(tanh′⋅Wc)(ot⋅tanh′)+ft\begin{aligned} \frac{\delta C^{(k)}}{\delta C^{(k-1)}} &=\frac{\delta C^{(k)}}{\delta f^{(k)}}\frac{\delta f^{(k)}}{\delta h^{(k-1)}}\frac{\delta h^{(k-1)}}{\delta C^{(k-1)}}[h_t公式]\\ &+\frac{\delta C^{(k)}}{\delta i^{(k)}}\frac{\delta i^{(k)}}{\delta h^{(k-1)}}\frac{\delta h^{(k-1)}}{\delta C^{(k-1)}}[i_t公式]\\ &+\frac{\delta C^{(k)}}{\delta \hat C^{(k)}}\frac{\delta \hat C^{(k)}}{\delta h^{(k-1)}}\frac{\delta h^{(k-1)}}{\delta C^{(k-1)}}[\hat C_t公式]\\ &+\frac{\delta C^{(k)}}{\delta C^{(k-1)}} [C_t公式]\\ &=C^{t-1}(\sigma'·W_f)(o^t·tanh')\\ &+\hat C^{t}(\sigma'·W_i)(o^t·tanh')\\ &+i^{t}(tanh'·W_c)(o^t·tanh')\\ &+f_t \end{aligned} δC(k−1)δC(k)=δf(k)δC(k)δh(k−1)δf(k)δC(k−1)δh(k−1)[ht公式]+δi(k)δC(k)δh(k−1)δi(k)δC(k−1)δh(k−1)[it公式]+δC^(k)δC(k)δh(k−1)δC^(k)δC(k−1)δh(k−1)[C^t公式]+δC(k−1)δC(k)[Ct公式]=Ct−1(σ′⋅Wf)(ot⋅tanh′)+C^t(σ′⋅Wi)(ot⋅tanh′)+it(tanh′⋅Wc)(ot⋅tanh′)+ft
因此RNN的问题∏j=kt\prod_{j=k}^t∏j=kt在LSTM中等价于(fk⋅fk+1⋅f2⋅...⋅ft)+other(f^{k}·f^{k+1}·f^{2}·...·f^{t})+other(fk⋅fk+1⋅f2⋅...⋅ft)+other -
正常梯度 + 消失梯度 = 正常梯度,总的远距离梯度就不会消失,因此 LSTM 可以解决梯度消失。
- 可自主选择[0,1]之间,当遗忘门接近 1时(例如模型初始化时会把 forget bias 设置成较大的正数,让遗忘门饱和),这时候远距离梯度不消失;
- 当遗忘门接近 0时,但这时模型是故意阻断梯度流的(例如情感分析任务中有一条样本 “A,但是 B”,模型读到“但是”后选择把遗忘门设置成 0,遗忘掉内容 A,这是合理的)。
-
正常梯度 + 爆炸梯度 = 爆炸梯度,因此 LSTM 仍然有可能发生梯度爆炸。不过,由于 LSTM 和普通 RNN 相比多经过了很多次激活函数(导数都小于 1),因此 LSTM 发生梯度爆炸的频率要低得多。
参考
https://zhuanlan.zhihu.com/p/25631496
https://www.cnblogs.com/bonelee/p/10475453.html
https://www.zhihu.com/question/34878706/answer/665429718