初读此书,对下图中这个公式无法理解。如果仿照作者的之前叙述的方法,需要进行矩阵对矩阵的求导,如何求导以及对于矩阵是否有链式法则笔者都没有了解,所以决定不以矩阵为变量尝试推导。
方法概述
将矩阵乘法理解为相应标量元素乘法和加法运算,再将得到的结果进行排列。
所以使用标量(矩阵中的元素)作为节点的输入输出,对计算图得到的反向传播结果进行排列,由此得到反向传播结果的矩阵,观察矩阵的形式,得到需要推导的公式。
解释开始
沿用了书中的例子,以下过程不能算作证明,只是对书中公式的一种解释。
Y
=
X
⋅
W
+
B
{\bf{Y=X\cdot W+B}}
Y=X⋅W+B
[
y
1
y
2
y
3
]
=
[
x
1
x
2
]
[
w
11
w
12
w
13
w
21
w
22
w
23
]
+
[
b
1
b
2
b
3
]
\begin{bmatrix} y_1 & y_2 & y_3 \\ \end{bmatrix}= \begin{bmatrix} x_1 & x_2 \\ \end{bmatrix} \begin{bmatrix} w_{11} & w_{12} & w_{13} \\ w_{21} & w_{22} & w_{23} \\ \end{bmatrix}+ \begin{bmatrix} b_1 & b_2 & b_3 \\ \end{bmatrix}
[y1y2y3]=[x1x2][w11w21w12w22w13w23]+[b1b2b3]
{
y
1
=
x
1
w
11
+
x
2
w
21
+
b
1
y
2
=
x
1
w
12
+
x
2
w
22
+
b
2
y
3
=
x
1
w
13
+
x
2
w
23
+
b
3
\begin{cases} y_1=x_1w_{11}+x_2w_{21}+b_1 \\ y_2=x_1w_{12}+x_2w_{22}+b_2 \\ y_3=x_1w_{13}+x_2w_{23}+b_3 \end{cases}
⎩⎪⎨⎪⎧y1=x1w11+x2w21+b1y2=x1w12+x2w22+b2y3=x1w13+x2w23+b3
1.绘制计算图
根据第一个等式可以画出以下计算图(令反向传播的输入为
z
1
z_1
z1,即
∂
L
∂
y
1
=
z
1
\frac{\partial L}{\partial y_1}=z_1
∂y1∂L=z1,并不假设为1,为了最后由矩阵得出公式更加方便):
可 以 得 到 ∂ L ∂ w 11 = x 1 z 1 , ∂ L ∂ w 21 = x 2 z 1 可以得到 {\color{red}\frac{\partial L}{\partial w_{11}}=x_1z_1,\frac{\partial L}{\partial w_{21}}=x_2z_1} 可以得到∂w11∂L=x1z1,∂w21∂L=x2z1
将三个等式的计算图一起画出(红色标出的是反向传播的结果), ∂ L ∂ Y = Z = [ z 1 z 2 z 3 ] \frac{\partial L}{\partial {\bf{Y}}}={\bf{Z}}=\begin{bmatrix} z_1 & z_2 & z_3 \\ \end{bmatrix} ∂Y∂L=Z=[z1z2z3]:

同 理 , 也 可 以 得 到 ∂ L ∂ w 12 , ∂ L ∂ w 13 , ∂ L ∂ w 22 , ∂ L ∂ w 23 同理,也可以得到\frac{\partial L}{\partial w_{12}},\frac{\partial L}{\partial w_{13}},\frac{\partial L}{\partial w_{22}},\frac{\partial L}{\partial w_{23}} 同理,也可以得到∂w12∂L,∂w13∂L,∂w22∂L,∂w23∂L
2.排列反向传播结果
∂
L
∂
W
=
[
∂
L
∂
w
11
∂
L
∂
w
12
∂
L
∂
w
13
∂
L
∂
w
21
∂
L
∂
w
22
∂
L
∂
w
23
]
=
[
x
1
z
1
x
1
z
2
x
1
z
3
x
2
z
1
x
2
z
2
x
2
z
3
]
=
[
x
1
x
2
]
[
z
1
z
2
z
3
]
=
X
T
⋅
∂
L
∂
Y
\large \begin{aligned} \frac{\partial L}{\partial \bf{W}} & =\begin{bmatrix} \frac{\partial L}{\partial w_{11}} & \frac{\partial L}{\partial w_{12}} & \frac{\partial L}{\partial w_{13}} \\ \frac{\partial L}{\partial w_{21}} & \frac{\partial L}{\partial w_{22}} & \frac{\partial L}{\partial w_{23}} \\ \end{bmatrix} \\ & = \begin{bmatrix} x_1z_1 & x_1z_2 & x_1z_3 \\ x_2z_1 & x_2z_2 & x_2z_3 \\ \end{bmatrix} \\ & = \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} \begin{bmatrix} z_1 & z_2 & z_3 \\ \end{bmatrix} \\ & = {\bf{X}}^T \cdot \frac{\partial L}{\partial {\bf{Y}}} \end{aligned}
∂W∂L=[∂w11∂L∂w21∂L∂w12∂L∂w22∂L∂w13∂L∂w23∂L]=[x1z1x2z1x1z2x2z2x1z3x2z3]=[x1x2][z1z2z3]=XT⋅∂Y∂L
∂
L
∂
W
\frac{\partial L}{\partial \bf{W}}
∂W∂L表示标量
L
L
L对矩阵
W
{\bf W}
W求导
这样就得到了其中一个公式。这部分主要是将结果进行排列,只要认真看书,理解图中的反向传播计算不成问题。
3.新知识引入
∂
L
∂
X
=
[
∂
L
∂
x
1
∂
L
∂
x
2
]
\frac{\partial L}{\partial {\bf{X}}} = \begin{bmatrix} \frac{\partial L}{\partial x_1} & \frac{\partial L}{\partial x_2} \end{bmatrix}
∂X∂L=[∂x1∂L∂x2∂L]
再使用图2就出现问题了,图2中元素
x
x
x反向传播得到的数据是六个,而标量
L
L
L对矩阵
X
{\bf X}
X求导得到的矩阵只有两个元素,导致不知道如何排列。其实是计算图画的不合适,权重
w
w
w的六个输入数据来自矩阵
W
{\bf{W}}
W的六个元素,但
x
x
x的六个输入数据只是来自矩阵
X
{\bf{X}}
X的两个元素。更合适的计算图如下图所示:
如何计算空白节点的反向传播,这涉及到一个新的知识点,作者在书的附录A中计算Softmax层的反向传播时有提到(主要观察下图中’
/
{\bf/}
/'结点的反向传播即可)。
正向传播时若有分支流出,则反向传播时它们的反向传播的值会相加。
因此,这里分成了三支的反向传播的值 ( − t 1 S , − t 2 S , − t 3 S ) (−t_1 S, −t_2 S, −t_3 S) (−t1S,−t2S,−t3S)会被求和。然后,还要对这个相加后的值进行“/”节点的反向传播,结果为 1 S ( t 1 + t 2 + t 3 ) \frac{1}{S}(t_1+t_2+t_3) S1(t1+t2+t3) 。
同理,图3中空白节点也有分支流出,反向传播时也应该将分支流入的值相加,就可以得到:
∂
L
∂
x
1
=
w
11
z
1
+
w
12
z
2
+
w
13
z
3
∂
L
∂
x
2
=
w
21
z
1
+
w
22
z
2
+
w
23
z
3
\begin{aligned} {\color{red}\frac{\partial L}{\partial x_1}=w_{11}z_1+w_{12}z_2+ w_{13}z_3} \\ {\color{red} \frac{\partial L}{\partial x_2} = w_{21}z_1 + w_{22}z_2 + w_{23}z_3} \end{aligned}
∂x1∂L=w11z1+w12z2+w13z3∂x2∂L=w21z1+w22z2+w23z3
然后再进行排列:
∂
L
∂
X
=
[
∂
L
∂
x
1
∂
L
∂
x
2
]
=
[
w
11
z
1
+
w
12
z
2
+
w
13
z
3
w
21
z
1
+
w
22
z
2
+
w
23
z
3
]
=
[
z
1
z
2
z
3
]
[
w
11
w
21
w
12
w
22
w
13
w
23
]
=
∂
L
∂
Y
⋅
W
T
\large \begin{aligned} \frac{\partial L}{\partial {\bf{X}}} & = \begin{bmatrix} \frac{\partial L}{\partial x_1} & \frac{\partial L}{\partial x_2} \end{bmatrix} \\ & = \begin{bmatrix} w_{11}z_1+w_{12}z_2+w_{13}z_3 & w_{21}z_1+w_{22}z_2+w_{23}z_3 \end{bmatrix} \\ & = \begin{bmatrix} z_1 & z_2 & z_3 \\ \end{bmatrix} \begin{bmatrix} w_{11} & w_{21} \\ w_{12} & w_{22} \\ w_{13} & w_{23} \\ \end{bmatrix} \\ & = \frac{\partial L}{\partial {\bf{Y}}} \cdot {\bf{W}}^T \end{aligned}
∂X∂L=[∂x1∂L∂x2∂L]=[w11z1+w12z2+w13z3w21z1+w22z2+w23z3]=[z1z2z3]⎣⎢⎡w11w12w13w21w22w23⎦⎥⎤=∂Y∂L⋅WT
4.更加严谨
可能节点中没有符号(空白节点)不够严谨,因为书中并没有给出节点中没有符号的例子,可以将空白节点视为’
×
×
ב节点,另一输入为1,或者视为’
+
+
+'节点,另一输入为0。
参考文献
1.斋藤康毅《深度学习入门-基于Python的理论与实现》
之前由于本人的疏忽,解释过程存在一些数学符号错误,现已修改过来,对于给大家理解时带来的误导和不便十分抱歉。
感谢 肌肉虾 指出本文中的存在错误!