在一些对神经网络可解释性的研究中,总是会利用到损失函数对最后一层特征图进行求梯度的操作,例如著名的Grad CAM,因此对于卷积神经网络的理解不能仅仅停留在调包的阶段,我们需要拆解开它求梯度的黑盒。
如图所示,假设有一个特征图AAA, 经过一个2×22 \times 22×2的卷积核KKK操作之后,得到一个新的特征图OOO,再将其展平后经过MLPMLPMLP得到一个长度为2的输出向量YYY。
如果想要知道特征图AAA的每个元素对最终输出的贡献大小,就需要计算出YYY对AAA中每个元素的偏导,即∂Y∂A\frac{ \partial Y }{ \partial A }∂A∂Y。
我们整理一下从特征图AAA得到输出YYY的过程,可以写为:
O=CONV(A)O=CONV(A)O=CONV(A)
Y=MLP(O)Y=MLP(O)Y=MLP(O)
因此根据链式求导法则,∂Y∂A=∂Y∂O∂O∂A\frac{ \partial Y }{ \partial A }= \frac{ \partial Y }{ \partial O} \frac{ \partial O }{ \partial A}∂A∂Y=∂O∂Y∂A∂O。
以输出Y1=68Y_1=68Y1=68为例,Y1=0∗O11+1∗O12+0∗O21+1∗O22Y_1=0*O_{11}+1*O_{12}+0*O_{21}+1*O_{22}Y1=0∗O11+1∗O12+0∗O21+1∗O22, 因此∂Y1∂O=[0101]\frac{ \partial Y_1 }{ \partial O }=[0 \quad1 \quad0\quad1]∂O∂Y1=[0101]
再来计算∂O∂A=[∂O11∂A11∂O11∂A12∂O11∂A13∂O11∂A21…∂O11∂A33∂O12∂A11∂O12∂A12∂O12∂A13∂O12∂A21…∂O12∂A33∂O21∂A11∂O21∂A12∂O21∂A13∂O21∂A21…∂O21∂A33∂O22∂A11∂O22∂A12∂O22∂A13∂O22∂A21…∂O22∂A33]=CT\frac{ \partial O }{ \partial A}=\begin{bmatrix} \frac{ \partial O_{11} }{ \partial A_{11}} & \frac{ \partial O_{11} }{ \partial A_{12}} & \frac{ \partial O_{11} }{ \partial A_{13}} & \frac{ \partial O_{11} }{ \partial A_{21}} & \dots & \frac{ \partial O_{11} }{ \partial A_{33}} \\ \frac{ \partial O_{12} }{ \partial A_{11}} & \frac{ \partial O_{12} }{ \partial A_{12}} & \frac{ \partial O_{12} }{ \partial A_{13}} & \frac{ \partial O_{12} }{ \partial A_{21}} & \dots & \frac{ \partial O_{12} }{ \partial A_{33}} \\ \frac{ \partial O_{21} }{ \partial A_{11}} & \frac{ \partial O_{21} }{ \partial A_{12}} & \frac{ \partial O_{21} }{ \partial A_{13}} & \frac{ \partial O_{21} }{ \partial A_{21}} & \dots & \frac{ \partial O_{21} }{ \partial A_{33}} \\ \frac{ \partial O_{22} }{ \partial A_{11}} & \frac{ \partial O_{22} }{ \partial A_{12}} & \frac{ \partial O_{22} }{ \partial A_{13}} & \frac{ \partial O_{22} }{ \partial A_{21}} & \dots & \frac{ \partial O_{22} }{ \partial A_{33}}\end{bmatrix} =C^T∂A∂O=⎣⎢⎢⎢⎡∂A11∂O11∂A11∂O12∂A11∂O21∂A11∂O22∂A12∂O11∂A12∂O12∂A12∂O21∂A12∂O22∂A13∂O11∂A13∂O12∂A13∂O21∂A13∂O22∂A21∂O11∂A21∂O12∂A21∂O21∂A21∂O22…………∂A33∂O11∂A33∂O12∂A33∂O21∂A33∂O22⎦⎥⎥⎥⎤=CT
最后将结果整合之后,再将形状变换与AAA相同即可,即[001024023]\begin{bmatrix} 0 & 0 & 1\\ 0 & 2 & 4 \\ 0 & 2 & 3\end{bmatrix}⎣⎡000022143⎦⎤。
以下是以上计算过程的代码,可以发现计算结果和推导是一致的。
import torch
import torch.nn as nn
X = torch.tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]]).reshape(1, 1, 3, 3).float()
X.requires_grad = True
kernel = torch.tensor([[0, 1],
[2, 3]]).reshape(1, 1, 2, 2).float()
conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2, bias=False)
conv.weight.data = kernel
fc = nn.Linear(in_features=4, out_features=2, bias=False)
fc.weight.data = torch.tensor([[0, 1, 0, 1],
[1, 0, 1, 1]]).float()
print(conv(X))
O = fc(torch.flatten(conv(X), start_dim=1))
print(O)
O[0][0].backward()
print(X.grad)