假设:X的维度为s×ns\times ns×n,其中s为样本数,每个样本均展平为1×n1\times n1×n的行向量;W维度为n×on\times on×o,其中o为全连接层的输出维度;b维度为1×o1\times o1×o;Z维度为s×os\times os×o,并且Z=X×W+bZ=X\times W+bZ=X×W+b
Z=[z11…z1o…zs1…zso]
Z=\begin{bmatrix}
z_{11} & … & z_{1o} \\
& … & \\
z_{s1} & … & z_{so}
\end{bmatrix}
Z=z11zs1………z1ozso
X=[x11…x1n…xs1…xsn]
X=\begin{bmatrix}
x_{11} & … & x_{1n} \\
& … & \\
x_{s1} & … & x_{sn}
\end{bmatrix}
X=x11xs1………x1nxsn
W=[w11…w1o…wn1…wno]
W=\begin{bmatrix}
w_{11} & … & w_{1o} \\
& … & \\
w_{n1} & … & w_{no}
\end{bmatrix}
W=w11wn1………w1owno
b=[b1…bo]
b=\begin{bmatrix}
b_{1} & … & b_{o} \\
\end{bmatrix}
b=[b1…bo]
损失函数L对W的梯度:
∂L∂W=XT×∂L∂Z\frac{\partial L}{\partial W}=X^T\times\frac{\partial L}{\partial Z}∂W∂L=XT×∂Z∂L
证明:\textbf{证明:}证明:
因为
zij=∑t=1nxitwtj+bjz_{ij}=\sum_{t=1}^n{x_{it}w_{tj}+b_j}zij=t=1∑nxitwtj+bj
所以∂zij∂wkl={xik,如果 j=l0,如果 j≠l\frac{\partial z_{ij}}{\partial w_{kl}}=
\left\{
\begin{aligned}
& x_{ik}, && \text{如果 } j=l \\
& 0, && \text{如果 } j \neq l
\end{aligned}
\right.∂wkl∂zij={xik,0,如果 j=l如果 j=l
所以
∂L∂wkl=∑i=1s∑j=1o∂L∂zij∂zij∂wkl=∑i=1s∂L∂zilxik\frac{\partial L}{\partial w_{kl}} =
{\sum_{i=1}^s \sum_{j=1}^{o} \frac{\partial L}{\partial z_{ij}} \frac{\partial z_{ij}}{\partial w_{kl}}}
=\sum_{i=1}^s\frac{\partial L}{\partial z_{il}}x_{ik}∂wkl∂L=i=1∑sj=1∑o∂zij∂L∂wkl∂zij=i=1∑s∂zil∂Lxik
又因为
XT×∂L∂Z=[x11…xs1…x1n…xsn]×[∂L∂z11…∂L∂z1o…∂L∂zs1…∂L∂zso]
X^T\times\frac{\partial L}{\partial Z}=
\begin{bmatrix}
x_{11} & … & x_{s1} \\
& … & \\
x_{1n} & … & x_{sn} \\
\end{bmatrix}
\times
\begin{bmatrix}
\frac{\partial L}{\partial z_{11}} & … & \frac{\partial L}{\partial z_{1o}} \\
& … & \\
\frac{\partial L}{\partial z_{s1}} & … & \frac{\partial L}{\partial z_{so}}
\end{bmatrix}
XT×∂Z∂L=x11x1n………xs1xsn×∂z11∂L∂zs1∂L………∂z1o∂L∂zso∂L
所以
(XT×∂L∂Z)kl=∑i=1s∂L∂zilxik=∂L∂wkl(X^T\times\frac{\partial L}{\partial Z})_{kl}=\sum_{i=1}^s\frac{\partial L}{\partial z_{il}}x_{ik}=\frac{\partial L}{\partial w_{kl}}(XT×∂Z∂L)kl=i=1∑s∂zil∂Lxik=∂wkl∂L
所以
∂L∂W=XT×∂L∂Z\frac{\partial L}{\partial W}=X^T\times\frac{\partial L}{\partial Z}∂W∂L=XT×∂Z∂L
损失函数L对X的梯度:
∂L∂X=∂L∂Z×WT\frac{\partial L}{\partial X}=\frac{\partial L}{\partial Z}\times W^T∂X∂L=∂Z∂L×WT
证明:\textbf{证明:}证明:
∂L∂xij=∑l=1s∑k=1o∂L∂zlk∂zlk∂xij\frac{\partial L}{\partial x_{ij}}=\sum_{l=1}^s\sum_{k=1}^o\frac{\partial L}{\partial z_{lk}}\frac{\partial z_{lk}}{\partial x_{ij}}
∂xij∂L=l=1∑sk=1∑o∂zlk∂L∂xij∂zlk
因为
zlk=∑l=1nxltwtk+bkz_{lk}=\sum_{l=1}^n{x_{lt}w_{tk}+b_k}zlk=l=1∑nxltwtk+bk
所以∂zlk∂xij={wjk,如果 l=i0,如果 l≠i\frac{\partial z_{lk}}{\partial x_{ij}}=
\left\{
\begin{aligned}
& w_{jk}, && \text{如果 } l=i \\
& 0, && \text{如果 } l \neq i
\end{aligned}
\right.∂xij∂zlk={wjk,0,如果 l=i如果 l=i
所以∂L∂xij=∑k=1o∂L∂zikwjk\frac{\partial L}{\partial x_{ij}}=\sum_{k=1}^o\frac{\partial L}{\partial z_{ik}}w_{jk}∂xij∂L=k=1∑o∂zik∂Lwjk
而∂L∂Z×WT=[∂L∂z11…∂L∂z1o…∂L∂zs1…∂L∂zso]×[w11…wn1…w1o…wno] \frac{\partial L}{\partial Z}\times W^T= \begin{bmatrix} \frac{\partial L}{\partial z_{11}} & … & \frac{\partial L}{\partial z_{1o}} \\ & … & \\ \frac{\partial L}{\partial z_{s1}} & … & \frac{\partial L}{\partial z_{so}} \end{bmatrix} \times \begin{bmatrix} w_{11} & … & w_{n1} \\ & … & \\ w_{1o} & … & w_{no} \end{bmatrix}∂Z∂L×WT=∂z11∂L∂zs1∂L………∂z1o∂L∂zso∂L×w11w1o………wn1wno
所以(∂L∂Z×WT)ij=∑k=1o∂L∂zikwjk=∂L∂xij(\frac{\partial L}{\partial Z}\times W^T)_{ij}=\sum_{k=1}^o\frac{\partial L}{\partial z_{ik}}w_{jk}=\frac{\partial L}{\partial x_{ij}}(∂Z∂L×WT)ij=k=1∑o∂zik∂Lwjk=∂xij∂L
所以∂L∂X=∂L∂Z×WT\frac{\partial L}{\partial X}=\frac{\partial L}{\partial Z}\times W^T∂X∂L=∂Z∂L×WT
损失函数L对b的梯度:
∂L∂b=sum(∂L∂Z,axis=0)#逐列求和\frac{\partial L}{\partial b}=sum(\frac{\partial L}{\partial Z}, axis=0) \#逐列求和∂b∂L=sum(∂Z∂L,axis=0)#逐列求和
证明:\textbf{证明:}证明:
因为
∂L∂bk=∑i=1s∑j=1o∂L∂zij∂zij∂bk=∑i=1s∂L∂zik\frac{\partial L}{\partial b_{k}} =
\sum_{i=1}^s \sum_{j=1}^{o} \frac{\partial L}{\partial z_{ij}} \frac{\partial z_{ij}}{\partial b_{k}}
=\sum_{i=1}^s\frac{\partial L}{\partial z_{ik}}∂bk∂L=i=1∑sj=1∑o∂zij∂L∂bk∂zij=i=1∑s∂zik∂L
所以
∂L∂b=1s∂L∂Z=sum(∂L∂Z,axis=0)\frac{\partial L}{\partial b}=\mathbf{1}_s\frac{\partial L}{\partial Z}=sum(\frac{\partial L}{\partial Z}, axis=0)∂b∂L=1s∂Z∂L=sum(∂Z∂L,axis=0)
其中1s\mathbf{1}_s1s为s列行向量