矩阵求导用来做什么:在机器学习中都是近世代数学运算,所以求导数也往往是对矩阵求导数。
矩阵求导的本质一个矩阵中的每一个元素对另外一个矩阵的每一个元素求导数。因此核心问题是明白求导以后,所得到的导数如何分布的。
我们由简单到复杂解释如何进行矩阵求导:
对于简单情况,x为向量y为标量,y对x每一个分量求导,不难理解有如下结论。f(x)代表着对x的运算。
x=[x1x2⋅⋅xn] y=f(x)
x= \begin{bmatrix}
x_1 \\
x_2 \\
·\\
·\\
x_n
\end{bmatrix} \,
y=f(x)
x=⎣⎢⎢⎢⎢⎡x1x2⋅⋅xn⎦⎥⎥⎥⎥⎤y=f(x)
∇xy=[dydx1dydx2⋅⋅dydxn]\nabla_{x}y= \begin{bmatrix}
\frac{\mathrm{d}y}{\mathrm{d}x_1} \\
\frac{\mathrm{d}y}{\mathrm{d}x_2}\\
·\\
·\\
\frac{\mathrm{d}y}{\mathrm{d}x_n}
\end{bmatrix}∇xy=⎣⎢⎢⎢⎢⎢⎡dx1dydx2dy⋅⋅dxndy⎦⎥⎥⎥⎥⎥⎤
同理可得x为标量y为向量的情况下
y=[f1(x),f2(x),~,fn(x)]
y=\begin{bmatrix}f_1(x),f_2(x),~,f_n(x)\end{bmatrix}
y=[f1(x),f2(x),~,fn(x)]
有:
∇xy=[dy1x,dy2x,~,dy3x]
\nabla_x y=\begin{bmatrix}\frac{\mathrm{d}y_1}{x},\frac{\mathrm{d}y_2}{x},~,\frac{\mathrm{d}y_3}{x}\end{bmatrix}
∇xy=[xdy1,xdy2,~,xdy3]
对于复杂的情况,比如x和y都是向量的情况下,我们可以分开来分析,首先将y看作一个整体,对x各个分量求导数。然后根据链式法则,或者是参考反向传播,y的每个分量都需要对x的特定分量进行求导。
x=[x1x2⋅⋅xn]y=[y1y2⋅⋅yn]x= \begin{bmatrix}
x_1 \\
x_2 \\
·\\
·\\
x_n
\end{bmatrix}
y= \begin{bmatrix}
y_1 \\
y_2 \\
·\\
·\\
y_n
\end{bmatrix}x=⎣⎢⎢⎢⎢⎡x1x2⋅⋅xn⎦⎥⎥⎥⎥⎤y=⎣⎢⎢⎢⎢⎡y1y2⋅⋅yn⎦⎥⎥⎥⎥⎤
∇xy=[dydx1dydx2⋅⋅dydxn]=[dy1dx1 dy2dx1 dy3dx1~ dy2dx1dy1dx2dy2dx2dy3dx2~dy4dx2⋅⋅dy1dxn,dy2dxn,dy3dxn~dy4dxn]\nabla_{x}y= \begin{bmatrix}
\frac{\mathrm{d}y}{\mathrm{d}x_1} \\
\frac{\mathrm{d}y}{\mathrm{d}x_2}\\
·\\
·\\
\frac{\mathrm{d}y}{\mathrm{d}x_n}
\end{bmatrix}=\begin{bmatrix}
\frac{\mathrm{d}y_1}{\mathrm{d}x_1} \, \frac{\mathrm{d}y_2}{\mathrm{d}x_1} \, \frac{\mathrm{d}y_3}{\mathrm{d}x_1}~\, \frac{\mathrm{d}y_2}{\mathrm{d}x_1}\\
\frac{\mathrm{d}y_1}{\mathrm{d}x_2} \frac{\mathrm{d}y_2}{\mathrm{d}x_2} \frac{\mathrm{d}y_3}{\mathrm{d}x_2}~ \frac{\mathrm{d}y_4}{\mathrm{d}x_2}\\
·\\
·\\
\frac{\mathrm{d}y_1}{\mathrm{d}x_n} , \frac{\mathrm{d}y_2}{\mathrm{d}x_n} , \frac{\mathrm{d}y_3}{\mathrm{d}x_n} ~\frac{\mathrm{d}y_4}{\mathrm{d}x_n}
\end{bmatrix}∇xy=⎣⎢⎢⎢⎢⎢⎡dx1dydx2dy⋅⋅dxndy⎦⎥⎥⎥⎥⎥⎤=⎣⎢⎢⎢⎢⎢⎡dx1dy1dx1dy2dx1dy3~dx1dy2dx2dy1dx2dy2dx2dy3~dx2dy4⋅⋅dxndy1,dxndy2,dxndy3~dxndy4⎦⎥⎥⎥⎥⎥⎤
常用的几个结论:
dxTdx=1\frac{\mathrm{d}x^T}{\mathrm{d}x}=1dxdxT=1
∇xXTAX=(A+AT)X\nabla_{x}X^TAX=(A+A^T)X∇xXTAX=(A+AT)X
∇u(t)v(t)=u(t)∇tv(t)+∇tu(t)v(t)\nabla u(t)v(t)=u(t) \nabla_t v(t)+ \nabla_tu(t)v(t)∇u(t)v(t)=u(t)∇tv(t)+∇tu(t)v(t)
应该足够应付大部分问题了