前言
由于个人理解能力有限,我看了好几遍李宏毅老师的推导才大致理解,如有错误之处恳请指正。
梯度与传播的关系
前面已经使用泰勒展开推导过损失函数值沿变量梯度的反方向下降最快的结论,考虑如下的网络结构
其中xnx_nxn代表输入的特征,wnw_nwn代表权重,bbb代表偏置,z=wx1+wx2+bz=wx_1+wx_2+bz=wx1+wx2+b,zzz作为激活函数σ(z)\sigma(z)σ(z)的参数。
此时的目的是得到损失函数LLL 对每个wiw_iwi 的偏导(损失函数是针对整个模型来说的)。
LLL 对w1w_1w1 的偏导并不那么好求,首先,LLL 是所有样本损失函数(也就是所有样本交叉熵CCC)的集合,因此这里考虑单个样本交叉熵CCC 对所有wiw_iwi 的偏导。
拿w1w_1w1 举例:从后往前看,lll 包含σ(z′)\sigma(z^{\prime})σ(z′)和σ(z′′)\sigma(z^{\prime\prime})σ(z′′),σ(z′)\sigma(z^{\prime})σ(z′)和σ(z′′)\sigma(z^{\prime\prime})σ(z′′)包含z′z^{\prime}z′和z′′z^{\prime\prime}z′′,z′z^{\prime}z′和z′′z^{\prime\prime}z′′包含σ(z)\sigma(z)σ(z),σ(z)\sigma(z)σ(z)包含zzz,zzz包含w1w_1w1…
感觉好复杂,不过也可以因此联想到高数学习的链式法则——g(f(x))g(f(x))g(f(x))对xxx积分等于f′(x)g′(f(x))f^{\prime}(x)g^{\prime}(f(x))f′(x)g′(f(x)).
链式法则(chain rule)
链式法则是反向传播算法的关键,通过链式法则,可以化繁为简,最终求得∂C∂w1\frac{\partial C}{\partial w_{1}}∂w1∂C,下图是链式法则的精髓所在:
反向传播的推导
假设有中间层如下
则可以根据链式法则求得∂C/∂w1\partial C/\partial w_{1}∂C/∂w1,从左向右(输入向输出)一点点看,首先:
∂C∂w1=∂C∂z∂z∂w1
\frac{\partial C}{\partial w_{1}}=\frac{\partial C}{\partial z} \frac{\partial z}{\partial w_{1}}
∂w1∂C=∂z∂C∂w1∂z
由于 z=w1x1+w2x2+bz=w_1x_1+w_2x_2+bz=w1x1+w2x2+b,因此 ∂z/∂w1=x1\partial z / \partial w_{1}=x_1∂z/∂w1=x1,下面求 ∂C/∂z\partial C/\partial z∂C/∂z:
∂C∂z=∂C∂a∂a∂z
\frac{\partial C}{\partial z}=\frac{\partial C}{\partial a} \frac{\partial a}{\partial z}
∂z∂C=∂a∂C∂z∂a
由于a=σ(z)a=\sigma(z)a=σ(z),因此∂σ(z)/∂z=σ′(z)\partial \sigma(z)/\partial z=\sigma^{\prime}(z)∂σ(z)/∂z=σ′(z),下面求∂C/∂a\partial C/\partial a∂C/∂a
回过头看链式法则的case 2,CCC 由aaa 发射出两条路汇成,因此:
∂C∂a=∂C∂z′∂z′∂a+∂C∂z′′∂z′′∂a
\frac{\partial C}{\partial a}=\frac{\partial C}{\partial z^{\prime}} \frac{\partial z^{\prime}}{\partial a}+\frac{\partial C}{\partial z^{\prime\prime}} \frac{\partial z^{\prime \prime}}{\partial a}
∂a∂C=∂z′∂C∂a∂z′+∂z′′∂C∂a∂z′′
而∂z′/∂a=w3\partial z^{\prime}/\partial a=w_3∂z′/∂a=w3,∂z′′/∂a=w4\partial z^{\prime\prime}/\partial a=w_4∂z′′/∂a=w4,因此上面可以变为:
∂C∂a=w3∂C∂z′+w4∂C∂z′′ \frac{\partial C}{\partial a}=w_3\frac{\partial C}{\partial z^{\prime}} +w_4\frac{\partial C}{\partial z^{\prime\prime}} ∂a∂C=w3∂z′∂C+w4∂z′′∂C
那么对于每一层,对照下图,便有:
∂C∂z=σ′(z)[w3∂C∂z′+w4∂C∂z′′]
\frac{\partial C}{\partial z}=\sigma^{\prime}(z)\left[w_{3} \frac{\partial C}{\partial z^{\prime}}+w_{4} \frac{\partial C}{\partial z^{\prime \prime}}\right]
∂z∂C=σ′(z)[w3∂z′∂C+w4∂z′′∂C]
如果z′z^{\prime}z′ 后面便对应着输出y1y_1y1,z′′z^{\prime\prime}z′′后面便对应着y2y_2y2,那么显然:
∂C∂z′=∂y1∂z′∂C∂y1且∂C∂z′′=∂y2∂z′′∂C∂y2
\frac{\partial C}{\partial z^{\prime}}=\frac{\partial y_{1}}{\partial z^{\prime}} \frac{\partial C}{\partial y_{1}} \quad且\quad \frac{\partial C}{\partial z^{\prime \prime}}=\frac{\partial y_{2}}{\partial z^{\prime \prime}} \frac{\partial C}{\partial y_{2}}
∂z′∂C=∂z′∂y1∂y1∂C且∂z′′∂C=∂z′′∂y2∂y2∂C
此时由于y1y_1y1 与 y2y_2y2 已知,很容易便可以算出上式。
如果a′a^{\prime}a′ 和a′′a^{\prime\prime}a′′ 后面还对应结点,则将a′a^{\prime}a′ 和a′′a^{\prime\prime}a′′ 作为输入,递归的进行上面的步骤直到输出层。
反向计算会使得很大程度上降低时间复杂度,对于∂C∂w=∂C∂z∂z∂w
\frac{\partial C}{\partial w_{}}=\frac{\partial C}{\partial z} \frac{\partial z}{\partial w_{}}
∂w∂C=∂z∂C∂w∂z可以前向计算每个结点的值作为∂z/∂w\partial z/\partial w∂z/∂w,通过反向传播计算∂C/∂z\partial C/\partial z∂C/∂z(计算公式在上面),二者相乘即为损失函数CCC 对所有www 进行偏微分的集合。