这几天看书的时候突然注意到了这个经典的优化方法,于是重新推导了一遍,为以后应用做参考。
背景
最小二乘法应该是我接触的最早的优化方法,也是求解线性回归的一种方法。线性回归的主要作用是用拟合的方式,求解两组变量之间的线性关系(当然也可以不是线性的,那就是另外的回归方法了)。也就是把一个系统的输出写成输入的线性组合的形式。而这组线性关系的参数求解方法,就是最小二乘法。
我们从最简单的线性回归开始,即输入和输出都是1维的。此时,最小二乘法也是最简单的。
假设有输入信号x={x0,x1,...,xt}x = \{x_0, x_1, ..., x_t\}x={x0,x1,...,xt},同时输出信号为y={y0,y1,...,yt}y = \{y_0, y_1, ..., y_t\}y={y0,y1,...,yt},我们假设输入信号xxx和输出信号yyy之间的关系可以写成如下形式:
ypre=ax+b(1)y_{pre} = ax+b \tag{1}ypre=ax+b(1)
我们需要求解最优的aaa和bbb,这里最优的含义就是,预测的最准确,也就是预测值和真实值的误差最小,即:
arg mina,b∑i=0t(yi−axi−b)2(2)arg\, min_{a, b}{\sum_{i=0}^{t}{(y_i-ax_i-b)^2}} \tag{2}argmina,bi=0∑t(yi−axi−b)2(2)
我们假设误差函数为:
err=∑i=0t(yi−axi−b)2(3)err = \sum_{i=0}^{t}{(y_i-ax_i-b)^2} \tag{3}err=i=0∑t(yi−axi−b)2(3)
errerrerr对aaa和bbb分别求偏导:
∂err∂a=∑i=0t2(axi+b−yi)∗xi(4)\frac{\partial{err}}{\partial{a}} = \sum_{i=0}^{t}{2(ax_i+b-y_i)*x_i} \tag{4}∂a∂err=i=0∑t2(axi+b−yi)∗xi(4)
∂err∂b=∑i=0t2(axi+b−yi)(5)\frac{\partial{err}}{\partial{b}} = \sum_{i=0}^{t}{2(ax_i+b-y_i)} \tag{5}∂b∂err=i=0∑t2(axi+b−yi)(5)
根据极值定理,有∂err∂a=0\frac{\partial{err}}{\partial{a}}=0∂a∂err=0,且∂err∂b=0\frac{\partial{err}}{\partial{b}}=0∂b∂err=0,所以有:
∑i=0t2(axi+b−yi)=0(6)\sum_{i=0}^{t}{2(ax_i+b-y_i)} = 0 \tag{6}i=0∑t2(axi+b−yi)=0(6)
∑i=0t(yi−axi)=∑i=0tb(7)\sum_{i=0}^{t}(y_i - ax_i) = \sum_{i=0}^{t}{b} \tag{7}i=0∑t(yi−axi)=i=0∑tb(7)
∑i=0tyi−a∗∑i=0txi=(t+1)∗b(8)\sum_{i=0}^{t}{y_i} - a * \sum_{i=0}^{t}{x_i} = (t+1)*b \tag{8}i=0∑tyi−a∗i=0∑txi=(t+1)∗b(8)
b=yˉ−axˉ(9)b = \bar{y} - a\bar{x} \tag{9}b=yˉ−axˉ(9)
其中,yˉ\bar{y}yˉ表示yyy的均值,xˉ\bar{x}xˉ表示xxx的均值。将Eq(9)代入Eq(4),有:
∑i=0t2(axi+b−yi)∗xi=0(10)\sum_{i=0}^{t}{2(ax_i+b-y_i)*x_i} = 0 \tag{10}i=0∑t2(axi+b−yi)∗xi=0(10)
∑i=0taxi2+∑i=0tbxi=∑i=0tyixi(11)\sum_{i=0}^{t}{ax_i^2} + \sum_{i=0}^{t}bx_i = \sum_{i=0}^{t}{y_ix_i} \tag{11}i=0∑taxi2+i=0∑tbxi=i=0∑tyixi(11)
a∑i=0txi2+xˉ(yˉ−axˉ)=∑i=0txiyi(12)a\sum_{i=0}^{t}x_i^2 + \bar{x}(\bar{y}-a\bar{x}) = \sum_{i=0}^{t}{x_iy_i} \tag{12}ai=0∑txi2+xˉ(yˉ−axˉ)=i=0∑txiyi(12)
a(∑i=0txi2−xˉ2)=∑i=0txiyi−xˉyˉ(13)a(\sum_{i=0}^{t}{x_i^2 - \bar{x}^2}) = \sum_{i=0}^{t}{x_iy_i}-\bar{x}\bar{y} \tag{13}a(i=0∑txi2−xˉ2)=i=0∑txiyi−xˉyˉ(13)
a=∑i=0txiyi−xˉyˉ∑i=0txi2−xˉ2(14)a = \frac{\sum_{i=0}^{t}{x_iy_i}-\bar{x}\bar{y}}{\sum_{i=0}^{t}{x_i^2 - \bar{x}^2}} \tag{14}a=∑i=0txi2−xˉ2∑i=0txiyi−xˉyˉ(14)
所以Eq(14)和Eq(9)就是最简单的最小二乘法的计算方法。
然后我们进一步考虑,如果输入和输出是多维数据,要如何计算。
假设输入信号为X∈Rm∗tX \in R^{m*t}X∈Rm∗t, 输出信号为Y∈Rn∗tY \in R^{n*t}Y∈Rn∗t,那么有:
Y=W0X+B=WX1(15)Y = W_0X+B = WX_1 \tag{15}Y=W0X+B=WX1(15)
其中W0∈Rn∗mW_0 \in R^{n*m}W0∈Rn∗m是回归矩阵的系数,B∈R1∗tB \in R^{1*t}B∈R1∗t表示常数项,这里可以直接写到WWW矩阵中。W∈Rn∗(m+1)W \in R^{n*(m+1)}W∈Rn∗(m+1),X1∈R(m+1)∗tX_1 \in R^{(m+1)*t}X1∈R(m+1)∗t
X1=[x11x12...x1tx11x12...x1t⋮⋮...⋮xm1xm2...xmt11...1](16)
X_1 = \begin{bmatrix}
x_{11} &x_{12} & ... &x_{1t}\\
x_{11} &x_{12} & ... &x_{1t}\\
{\vdots} &{\vdots} &... &{\vdots}\\
x_{m1} &x_{m2} &... &x_{mt}\\
1 &1 &... &1\\
\end{bmatrix} \tag{16}
X1=x11x11⋮xm11x12x12⋮xm21...............x1tx1t⋮xmt1(16)
所以有:
argminW(Y−WX1)(17)\arg min_{W}({Y-WX_1}) \tag{17}argminW(Y−WX1)(17)
假设误差函数为EEE,则有:
E=(Y−WX1)(Y−WX1)T=YYT−WX1YT−YX1TWT+WX1X1TWT(18)E = (Y-WX_1)(Y-WX_1)^T = YY^T - WX_1Y^T-YX_1^TW^T+WX_1X_1^TW^T \tag{18}E=(Y−WX1)(Y−WX1)T=YYT−WX1YT−YX1TWT+WX1X1TWT(18)
计算EEE对WWW的偏导,则该偏导等于0:
∂E∂W=−X1YT−X1YT+2WXXT=0(19)\frac{\partial{E}}{\partial{W}} = -X_1Y^T-X_1Y^T + 2WXX^T = 0 \tag{19}∂W∂E=−X1YT−X1YT+2WXXT=0(19)
所以有:
W=(X1X1T)−1X1YT(20)W = (X_1X_1^T)^{-1}X_1Y^T \tag{20}W=(X1X1T)−1X1YT(20)
至此矩阵形式的最小二乘法(多元线性回归的参数解法)推导完成。注意这里的X1X_1X1和YYY中的数据排列方式为:每一行是一个维度的数据,每一列表示一个时间点。如果不是这么记录的话,那么公式需要加上转置。
后续会附上代码链接