对神经网络梯度反向传播的理解
想写这个话题的起因是因为在读论文时,读到了下面这句话:
we visualize gradients with respect to the square of the norm of the last convolutional layer in the network, backpropagated into the input image, and visualized as a function of training data.
翻译过来就是:
我们将损失函数对最后一个卷积层的梯度的范数的平方传递到输入图像层,并将其可视化为训练数据的函数
翻译完之后感觉很拗口,我觉得实际是在讲这么一回事:
将损失函数相对于某一参数的梯度的范数的平方,从最后一个卷积层开始反向传递,一直传递到输入图像这一层,将这时的梯度值可视化出来(如通过灰度映射)
论文中给的可视化结果如下所示:
重点关注在L-Bird数据集上的效果,梯度大的位置和目标的位置强相关。可以认为是注意力机制的某种形式。
如果想看论文原文的话, 链接在下:
https://arxiv.org/abs/1511.06789v1
其实,所谓的将梯度传递到输入层,就是反向传播的过程,通过链式法则,逐步求出损失函数关于某一参数的偏导数,在传递到输入图像这一层时的值
示例:简单的卷积神经网络
考虑一个简单的卷积神经网络(CNN)架构,结构如下:
- 输入层:输入图像,尺寸为 3×33 \times 33×3(即 X∈R3×3X \in \mathbb{R}^{3 \times 3}X∈R3×3)。
- 卷积层:使用一个 2×22 \times 22×2 的卷积核,输出通道数为 1,步幅为 1。
- 激活层:使用 ReLU 激活函数。
- 全连接层:将卷积层的输出展平,并通过一个全连接层得到最终的输出。
- 损失函数:使用均方误差(MSE)损失函数。
1. 模型定义
假设输入图像 XXX 是一个 3×33 \times 33×3 的矩阵(即 X∈R3×3X \in \mathbb{R}^{3 \times 3}X∈R3×3),卷积核 WWW 是一个 2×22 \times 22×2 的矩阵(即 W∈R2×2W \in \mathbb{R}^{2 \times 2}W∈R2×2),偏置 bbb 是一个标量。
-
卷积操作:假设输出为 2×22 \times 22×2 的特征图:
Z=Conv(X,W)+b Z = \text{Conv}(X, W) + b Z=Conv(X,W)+b -
ReLU 激活:
A=ReLU(Z) A = \text{ReLU}(Z) A=ReLU(Z) -
展平:将卷积层经过激活后的输出(即 AAA)展平为一个向量 aaa。
-
全连接层:通过一个全连接层计算输出:
y^=WfcTa+bfc \hat{y} = W_{fc}^T a + b_{fc} y^=WfcTa+bfc
其中 WfcW_{fc}Wfc 是全连接层的权重,bfcb_{fc}bfc 是偏置。 -
损失函数:假设目标值为 ttt,损失函数 LLL 是:
L=12(y^−t)2 L = \frac{1}{2} (\hat{y} - t)^2 L=21(y^−t)2
2. 前向传播
-
卷积操作:
Z=Conv(X,W)+b Z = \text{Conv}(X, W) + b Z=Conv(X,W)+b -
ReLU 激活:
A=ReLU(Z) A = \text{ReLU}(Z) A=ReLU(Z) -
展平:
a=Flatten(A) a = \text{Flatten}(A) a=Flatten(A) -
全连接层:
y^=WfcTa+bfc \hat{y} = W_{fc}^T a + b_{fc} y^=WfcTa+bfc -
计算损失:
L=12(y^−t)2 L = \frac{1}{2} (\hat{y} - t)^2 L=21(y^−t)2
3. 梯度计算
损失对全连接层参数的梯度:
-
损失对全连接层输出的梯度:
∂L∂y^=y^−t \frac{\partial L}{\partial \hat{y}} = \hat{y} - t ∂y^∂L=y^−t -
损失对全连接层权重 WfcW_{fc}Wfc 的梯度:
∂L∂Wfc=∂L∂y^⋅∂y^∂Wfc=(y^−t)⋅a \frac{\partial L}{\partial W_{fc}} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial W_{fc}} = (\hat{y} - t) \cdot a ∂Wfc∂L=∂y^∂L⋅∂Wfc∂y^=(y^−t)⋅a -
损失对全连接层偏置 bfcb_{fc}bfc 的梯度:
∂L∂bfc=∂L∂y^⋅∂y^∂bfc=y^−t \frac{\partial L}{\partial b_{fc}} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial b_{fc}} = \hat{y} - t ∂bfc∂L=∂y^∂L⋅∂bfc∂y^=y^−t
反向传播到卷积层的梯度:
-
传递到展平后的向量a:
∂L∂a=∂L∂y^⋅∂y^∂a=(y^−t)⋅Wfc \frac{\partial L}{\partial a} = \frac{\partial L}{\partial \hat{y}} \cdot \frac{\partial \hat{y}}{\partial a} = (\hat{y}-t) \cdot W_{fc} ∂a∂L=∂y^∂L⋅∂a∂y^=(y^−t)⋅Wfc -
传递到卷积层的梯度:
∂L∂Z=∂L∂A⋅∂A∂Z=∂L∂A⋅ReLU′(Z)=∂L∂a⋅ReLU′(Z) \frac{\partial L}{\partial Z} = \frac{\partial L}{\partial A} \cdot \frac{\partial A}{\partial Z} =\frac{\partial L}{\partial A} \cdot \text{ReLU}'(Z) = \frac{\partial L}{\partial a} \cdot \text{ReLU}'(Z) ∂Z∂L=∂A∂L⋅∂Z∂A=∂A∂L⋅ReLU′(Z)=∂a∂L⋅ReLU′(Z)
因为从A到a只有一个展平操作,这里认为 ∂L∂A=∂L∂a\frac{\partial L}{\partial A} = \frac{\partial L}{\partial a}∂A∂L=∂a∂L
其中,ReLU′(Z)\text{ReLU}'(Z)ReLU′(Z) 是 ReLU 激活的导数:
ReLU′(z)={1if z>00if z≤0 \text{ReLU}'(z) = \begin{cases} 1 & \text{if } z > 0 \\ 0 & \text{if } z \leq 0 \end{cases} ReLU′(z)={10if z>0if z≤0 -
损失对卷积核 WWW 的梯度:
∂L∂W=∂L∂Z⋅∂Z∂W=∂L∂Z∗X \frac{\partial L}{\partial W} = \frac{\partial L}{\partial Z} \cdot \frac{\partial Z}{\partial W}=\frac{\partial L}{\partial Z} \ast X ∂W∂L=∂Z∂L⋅∂W∂Z=∂Z∂L∗X
其中,* 表示卷积操作中的梯度传递。具体而言,可以计算每个位置的梯度,然后对整个卷积核进行求和。每一个位置可能不止加一次
这里复杂的就是∂Z∂W\frac{\partial Z}{\partial W}∂W∂Z,即求Z=Conv(X,W)+bZ=Conv(X,W)+bZ=Conv(X,W)+b 关于W,bW, bW,b的导数,因为采用滑窗方式进行卷积,每一次滑窗就会覆盖输入图像上kernel∗kernelkernel*kernelkernel∗kernel个像素,对W的导数就=这几个像素值的加和,逐步滑窗,逐步累加,直至完成卷积运算。
- 损失对卷积层偏置 bbb 的梯度:
∂L∂b=∑∂L∂Z \frac{\partial L}{\partial b} = \sum \frac{\partial L}{\partial Z} ∂b∂L=∑∂Z∂L
其中加和的个数 = 上一步中加和的次数一致
反向传播到输入图像的梯度
- 损失对输入图像 XXX 的梯度:
∂L∂X=∂L∂Z⋅∂Z∂X=∂L∂Z∗W \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Z} \cdot \frac{\partial Z}{\partial X}=\frac{\partial L}{\partial Z} \ast W ∂X∂L=∂Z∂L⋅∂X∂Z=∂Z∂L∗W
这里的解释和 损失对卷积核 WWW 的梯度 相似。
总结
通过上述步骤,我们可以看到如何从损失函数开始,通过网络的各层计算梯度,并将这些梯度反向传播到输入图像。具体过程包括:
- 前向传播:计算每层的输出和损失。
- 反向传播:使用链式法则计算梯度并将其从输出层传递回输入层。