一. 概述:
想要真正了解梯度爆炸和消失问题,必须手推反向传播,了解反向传播里梯度更新到底是怎么样更新的,所有问题都需要用数学表达式来说明,经过手推之后,便可分析出是什么原因导致的。本人就是在手推之后,才真正了解了这个问题发生的本质,所以本文以手推反向传播开始。
二. 手推反向传播:
以上图为例开始推起来,先说明几点,i1,i2是输入节点,h1,h2为隐藏层节点,o1,o2为输出层节点,除了输入层,其他两层的节点结构为下图所示:
举例说明,NETo1 为输出层的输入,也就是隐藏层的输出经过线性变换后的值, OUTo1 为经过激活函数sigmoid后的值;同理 NETh1 为隐藏层的输入,也就是输入层经过线性变换后的值, OUTh1 为经过激活函数sigmoid 的值。只有这两层有激活函数,输入层没有。
定义一下sigmoid的函数: σ(z)=11+e−z
说一下sigmoid的求导:
σ′(z)=(11+e−z)′=e−z(1+e−z)2=1+e−z−1(1+e−z)2=σ(z)(1+e−z)2=σ(z)(1−σ(z))
定义一下损失函数,这里的损失函数是均方误差函数,即:
Losstotal=∑12(target - output)2
具体到上图,就是:
Losstotal=12(target1 - out_o1)2+12(target2 - out_o2)2
到这里,所有前提就交代清楚了,前向传播就不推了,默认大家都会,下面推反向传播。
- 第一个反向传播(热身)
先来一个简单的热热身,求一下损失函数对W5的偏导,即: ∂Losstotal∂w5
首先根据链式求导法则写出对W5求偏导的总公式,再把图拿下来对照(如上),可以看出,需要计算三部分的求导,下面就一步一步来:
总公式:总公式:∂Losstotal∂w5=∂Losstotal∂outo1∂outo1∂neto1∂neto1∂w5
第一步:第一步:∂Losstotal∂outo1=∂12(target1−outo1)2+12(target2−outo2)2∂outo1=outo1−target1
第二步:第二步:∂outo1∂neto1=∂11+e−neto1∂neto1=σ(neto1)(1−σ(neto1))
第三步:第三步:∂neto1∂w5=∂outh1w5+outh2w6∂w5=outh1
综上三个步骤,得到总公式:
总公式:总公式:∂Losstotal∂w5=(outo1−target1)⋅(σ(neto1)(1−σ(neto1)))⋅outh1
- 第二个反向传播:
接下来,要求损失函数对w1的偏导,即: ∂Losstotal∂w1
还是把图摆在这,方便看,先写出总公式,对w1求导有个地方要注意,w1的影响不仅来自o1还来自o2,从图上可以一目了然,所以总公式为:
总公式:总公式:∂Losstotal∂w1=∂Losstotal∂outo1∂outo1∂neto1∂neto1∂outh1∂outh1∂neth1∂neth1∂w1+∂Losstotal∂outo2∂outo2∂neto2∂neto2∂outh1∂outh1∂neth1∂neth1∂w1
所以总共分为左右两个式子,分别又对应5个步骤,详细写一下左边,右边同理:
第一步:第一步:∂Losstotal∂outo1=outo1−target1第二步:第二步:∂outo1∂neto1=σ(neto1)(1−σ(neto1))
第三步:第三步:∂neto1∂outh1=∂outh1w5+outh2w6∂outh1=w5
第四步:第四步:∂outh1∂neth1=σ(neth1)(1−σ(neth1))
第二步:第二步:∂neth1∂w1=∂i1w1+i2w2∂w1=i1
右边也是同理,就不详细写了,写一下总的公式:
∂Losstotal∂w1=((outo1−target1)⋅(σ(neto1)(1−σ(neto1)))⋅w5⋅(σ(neth1)(1−σ(neth1)))⋅i1)+((outo2−target2)⋅(σ(neto2)(1−σ(neto2)))⋅w7⋅(σ(neth1)(1−σ(neth1)))⋅i1)
这个公式只是对如此简单的一个网络结构的一个节点的偏导,就这么复杂。。亲自推完才深深的意识到。。。
为了后面描述方便,把上面的公式化简一下, outo1−target1 记为 Co1 , σ(neto1)(1−σ(neto1)) 记为 σ(neto1)′ ,则:
∂Losstotal∂w1=Co1⋅σ(neto1)′⋅w5⋅σ(neth1)′⋅i1+Co2⋅σ(neto2)′⋅w7⋅σ(neth1)′⋅i1
三. 梯度消失,爆炸产生原因:
从上式其实已经能看出来,求和操作其实不影响,主要是是看乘法操作就可以说明问题,可以看出,损失函数对w1的偏导,与 Co1 ,权重w,sigmoid的导数有关,明明还有输入i为什么不提?因为如果是多层神经网络的中间某层的某个节点,那么就没有输入什么事了。所以产生影响的就是刚刚提的三个因素。
再详细点描述,如图,多层神经网络:
参考:
PENG:神经网络训练中的梯度消失与梯度爆炸291 赞同 · 26 评论文章正在上传…重新上传取消
假设(假设每一层只有一个神经元且对于每一层 yi=σ(zi)=σ(wixi+bi),其中σ为sigmoid函数),如图:
则:
∂C∂b1=∂C∂y4∂y4∂z4∂z4∂x4∂x4∂z3∂z3∂x3∂x3∂z2∂z2∂x2∂x2∂z1∂z1∂b1=Cy4σ′(z4)w4σ′(z3)w3σ′(z2)w2σ′(z1)
看一下sigmoid函数的求导之后的样子:
发现sigmoid函数求导后最大最大也只能是0.25。
再来看W,一般我们初始化权重参数W时,通常都小于1,用的最多的应该是0,1正态分布吧。
所以 |σ′(z)w|≤0.25 ,多个小于1的数连乘之后,那将会越来越小,导致靠近输入层的层的权重的偏导几乎为0,也就是说几乎不更新,这就是梯度消失的根本原因。
再来看看梯度爆炸的原因,也就是说如果 |σ′(z)w|≥1 时,连乘下来就会导致梯度过大,导致梯度更新幅度特别大,可能会溢出,导致模型无法收敛。sigmoid的函数是不可能大于1了,上图看的很清楚,那只能是w了,这也就是经常看到别人博客里的一句话,初始权重过大,一直不理解为啥。。现在明白了。
但梯度爆炸的情况一般不会发生,对于sigmoid函数来说, σ(z)′ 的大小也与w有关,因为 z=wx+b ,除非该层的输入值x在一直一个比较小的范围内。
其实梯度爆炸和梯度消失问题都是因为网络太深,网络权值更新不稳定造成的,本质上是因为梯度反向传播中的连乘效应。
所以,总结一下,为什么会发生梯度爆炸和消失:
本质上是因为神经网络的更新方法,梯度消失是因为反向传播过程中对梯度的求解会产生sigmoid导数和参数的连乘,sigmoid导数的最大值为0.25,权重一般初始都在0,1之间,乘积小于1,多层的话就会有多个小于1的值连乘,导致靠近输入层的梯度几乎为0,得不到更新。梯度爆炸是也是同样的原因,只是如果初始权重大于1,或者更大一些,多个大于1的值连乘,将会很大或溢出,导致梯度更新过大,模型无法收敛。
四. 梯度消失,爆炸解决方案:
参考:
DoubleV:详解深度学习中的梯度消失、爆炸原因及其解决方法1133 赞同 · 62 评论文章正在上传…重新上传取消
解决方案一(预训练加微调):
此方法来自Hinton在2006年发表的一篇论文,Hinton为了解决梯度的问题,提出采取无监督逐层训练方法,其基本思想是每次训练一层隐节点,训练时将上一层隐节点的输出作为输入,而本层隐节点的输出作为下一层隐节点的输入,此过程就是逐层“预训练”(pre-training);在预训练完成后,再对整个网络进行“微调”(fine-tunning)。Hinton在训练深度信念网络(Deep Belief Networks中,使用了这个方法,在各层预训练完成后,再利用BP算法对整个网络进行训练。此思想相当于是先寻找局部最优,然后整合起来寻找全局最优,此方法有一定的好处,但是目前应用的不是很多了。
解决方案二(梯度剪切、正则):
梯度剪切这个方案主要是针对梯度爆炸提出的,其思想是设置一个梯度剪切阈值,然后更新梯度的时候,如果梯度超过这个阈值,那么就将其强制限制在这个范围之内。这可以防止梯度爆炸。
正则化是通过对网络权重做正则限制过拟合,仔细看正则项在损失函数的形式:
Loss=(y−WTx)2+α||W||2
其中, α 是指正则项系数,因此,如果发生梯度爆炸,权值的范数就会变的非常大,通过正则化项,可以部分限制梯度爆炸的发生。
注:事实上,在深度神经网络中,往往是梯度消失出现的更多一些
解决方案三(改变激活函数):
首先说明一点,tanh激活函数不能有效的改善这个问题,先来看tanh的形式:
tanh(x)=ex−e−xex+e−x
再来看tanh的导数图像:
发现虽然比sigmoid的好一点,sigmoid的最大值小于0.25,tanh的最大值小于1,但仍是小于1的,所以并不能解决这个
Relu:思想也很简单,如果激活函数的导数为1,那么就不存在梯度消失爆炸的问题了,每层的网络都可以得到相同的更新速度,relu就这样应运而生。先看一下relu的数学表达式:
Relu(x)=max(x,0)={0,x<0x,x>0}
从上图中,我们可以很容易看出,relu函数的导数在正数部分是恒等于1的,因此在深层网络中使用relu激活函数就不会导致梯度消失和爆炸的问题。
relu的主要贡献在于:
- 解决了梯度消失、爆炸的问题
- 计算方便,计算速度快
- 加速了网络的训练
同时也存在一些缺点:
- 由于负数部分恒为0,会导致一些神经元无法激活(可通过设置小学习率部分解决)
- 输出不是以0为中心的
leakrelu
leakrelu就是为了解决relu的0区间带来的影响,其数学表达为: leakrelu=f(x)={x,x>0x∗k,x≤0 其中k是leak系数,一般选择0.1或者0.2,或者通过学习而来解决死神经元的问题。
leakrelu解决了0区间带来的影响,而且包含了relu的所有优点
elu
elu激活函数也是为了解决relu的0区间带来的影响,其数学表达为:
{x, if x>0α(ex−1), otherwise
其函数及其导数数学形式为:
但是elu相对于leakrelu来说,计算要更耗时间一些,因为有e。
解决方案四(batchnorm):
Batchnorm是深度学习发展以来提出的最重要的成果之一了,目前已经被广泛的应用到了各大网络中,具有加速网络收敛速度,提升训练稳定性的效果,Batchnorm本质上是解决反向传播过程中的梯度问题。batchnorm全名是batch normalization,简称BN,即批规范化,通过规范化操作将输出信号x规范化到均值为0,方差为1保证网络的稳定性。
具体的batchnorm原理非常复杂,在这里不做详细展开,此部分大概讲一下batchnorm解决梯度的问题上。具体来说就是反向传播中,经过每一层的梯度会乘以该层的权重,举个简单例子: 正向传播中f3=f2(wT∗x+b),那么反向传播中,∂f2∂x=∂f2∂f1w,反向传播式子中有w的存在,所以w的大小影响了梯度的消失和爆炸,batchnorm就是通过对每一层的输出做scale和shift的方法,通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到接近均值为0方差为1的标准正太分布,即严重偏离的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,使得让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。
解决方案五(残差结构):
如图,把输入加入到某层中,这样求导时,总会有个1在,这样就不会梯度消失了。
∂loss∂xl=∂loss∂xL⋅∂xL∂xl=∂loss∂xL⋅(1+∂∂xL∑i=lL−1F(xi,Wi))
式子的第一个因子 ∂loss∂xL 表示的损失函数到达 L 的梯度,小括号中的1表明短路机制可以无损地传播梯度,而另外一项残差梯度则需要经过带有weights的层,梯度不是直接传递过来的。残差梯度不会那么巧全为-1,而且就算其比较小,有1的存在也不会导致梯度消失。所以残差学习会更容易。
注:上面的推导并不是严格的证
,只为帮助理解
解决方案六(LSTM):
在介绍这个方案之前,有必要来推导一下RNN的反向传播,因为关于梯度消失的含义它跟DNN不一样!不一样!不一样!
先推导再来说,从这copy的:
沉默中的思索:RNN梯度消失和爆炸的原因594 赞同 · 78 评论文章正在上传…重新上传取消
RNN结构如图:
假设我们的时间序列只有三段, S0 为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下:
S1=WxX1+WsS0+b1O1=WoS1+b2
S2=WxX2+WsS1+b1O2=WoS2+b2
S3=WxX3+WsS2+b1O3=WoS3+b2
假设在t=3时刻,损失函数为 L3=12(Y3−O3)2 。
则对于一次训练任务的损失函数为 L=∑t=0TLt ,即每一时刻损失值的累加。
使用随机梯度下降法训练RNN其实就是对 Wx 、 Ws 、 Wo 以及 b1b2 求偏导,并不断调整它们以使L尽可能达到最小的过程。
现在假设我们我们的时间序列只有三段,t1,t2,t3。
我们只对t3时刻的 、、Wx、Ws、W0 求偏导(其他时刻类似):
∂L3∂W0=∂L3∂O3∂O3∂Wo
∂L3∂Wx=∂L3∂O3∂O3∂S3∂S3∂Wx+∂L3∂O3∂O3∂S3∂S3∂S2∂S2∂Wx+∂L3∂O3∂O3∂S3∂S3∂S2∂S2∂S1∂S1∂Wx
∂L3∂Ws=∂L3∂O3∂O3∂S3∂S3∂Ws+∂L3∂O3∂O3∂S3∂S3∂S2∂S2∂Ws+∂L3∂O3∂O3∂S3∂S3∂S2∂S2∂S1∂S1∂Ws
可以看出对于 W0 求偏导并没有长期依赖,但是对于 、Wx、Ws 求偏导,会随着时间序列产生长期依赖。因为 St 随着时间序列向前传播,而 St 又是 、Wx、Ws的函数。
根据上述求偏导的过程,我们可以得出任意时刻对 、Wx、Ws 求偏导的公式:
∂Lt∂Wx=∑k=0t∂Lt∂Ot∂Ot∂St(∏j=k+1t∂Sj∂Sj−1)∂Sk∂Wx
任意时刻对Ws 求偏导的公式同上。
如果加上激活函数, Sj=tanh(WxXj+WsSj−1+b1) ,
则 ∏j=k+1t∂Sj∂Sj−1 = ∏j=k+1ttanh′Ws
激活函数tanh和它的导数图像在上面已经说过了,所以原因在这就不赘述了,还是一样的,激活函数导数小于1。
现在来解释一下,为什么说RNN和DNN的梯度消失问题含义不一样?
- 先来说DNN中的反向传播:在上文的DNN反向传播中,我推导了两个权重的梯度,第一个梯度是直接连接着输出层的梯度,求解起来并没有梯度消失或爆炸的问题,因为它没有连乘,只需要计算一步。第二个梯度出现了连乘,也就是说越靠近输入层的权重,梯度消失或爆炸的问题越严重,可能就会消失会爆炸。一句话总结一下,DNN中各个权重的梯度是独立的,该消失的就会消失,不会消失的就不会消失。
- 再来说RNN:RNN的特殊性在于,它的权重是共享的。抛开W_o不谈,因为它在某时刻的梯度不会出现问题(某时刻并不依赖于前面的时刻),但是W_s和W_x就不一样了,每一时刻都由前面所有时刻共同决定,是一个相加的过程,这样的话就有个问题,当距离长了,计算最前面的导数时,最前面的导数就会消失或爆炸,但当前时刻整体的梯度并不会消失,因为它是求和的过程,当下的梯度总会在,只是前面的梯度没了,但是更新时,由于权值共享,所以整体的梯度还是会更新,通常人们所说的梯度消失就是指的这个,指的是当下梯度更新时,用不到前面的信息了,因为距离长了,前面的梯度就会消失,也就是没有前面的信息了,但要知道,整体的梯度并不会消失,因为当下的梯度还在,并没有消失。
- 一句话概括:RNN的梯度不会消失,RNN的梯度消失指的是当下梯度用不到前面的梯度了,但DNN靠近输入的权重的梯度是真的会消失。
说完了RNN的反向传播及梯度消失的含义,终于该说为什么LSTM可以解决这个问题了,这里默认大家都懂LSTM的结构,对结构不做过多的描述。