6.3 实践:曲线你和问题
6.3.1 手写高斯牛顿法
* 如何手写高斯牛顿法,如何使用对应的优化库求解这个问题。
我们设定一个曲线方程:
y
=
e
x
p
(
a
x
2
+
b
x
+
c
)
+
w
y=exp(ax^2+bx+c)+w
y=exp(ax2+bx+c)+w
- 设定如下:
- a,b,c为曲线的参数,也是我们需要迭代的参数
- w为高斯噪声,满足 w ( 0 , σ 2 w~(0,\sigma ^2 w (0,σ2
- x,y为真值参数
那么 我们的目标是要求解以下目标函数
下面的误差函数如下所示:
e
i
=
y
i
−
e
x
p
(
a
x
i
2
+
b
x
i
+
c
)
e_i=y_i-exp(ax_i^2+bx_i+c)
ei=yi−exp(axi2+bxi+c)
那么,可以求出每个误差向对于状态变量的导数 于是最后公式可以表示为如下:
- 解释:高斯牛顿法中的J是以列为主的,如下所示:
因此在这里你可以把所有
J
i
J_i
Ji排成一行,将这个方程写成矩阵形式,不过含义与求和符号是一致的。这里的
σ
2
\sigma ^2
σ2就是方差,这里的
J
i
J
i
T
J_i J_i^T
JiJiT是列乘以行。因此这里的雅可比矩阵和常规的雅可比矩阵是不同的,也就是常规雅可比矩阵的转置。常规的可以矩阵定义如下:
程序实现 gaussNewton.cpp
#include <iostream>
#include <opencv2/opencv.hpp>
#include <Eigen/Core>
#include <Eigen/Dense>
using namespace std;
using namespace Eigen;
int main(int argc,char **argv){
double ar=1.0,br=2.0,cr=1.0; //真实参数值
double ae=2.0,be=-1.0,ce=5.0; //估计参数值
int N=100; //需要100个数据点
double w_sigma=1.0; //标准差 sigma的值
double inv_sigma=1.0/w_sigma; //标准差的倒数
cv::RNG rng; //这个是OpenCV随机数产生器
vector<double> x_data,y_data; //数据
for(int i=0;i<N;i++){
double x=i/100.0;
x_data.push_back(x);
y_data.push_back(exp(ar*x*x+br*x+cr)+rng.gaussian(w_sigma*w_sigma));
} //这边的话真值产生完毕
//开始Gauss-Newton迭代
int iterations=100; //迭代次数
double cost=0,lastCost=0; //本次迭代和上次迭代的cost
chrono::steady_clock::time_point t1=chrono::steady_clock::now();
for(int iter=0;iter<iterations;iter++){
Matrix3d H=Matrix3d::Zero(); //new Hessian=J^T W^{-1} J inGauss-Newton ,Some book describe it as J W^{-1} J^T
Vector3d b=Vector3d::Zero(); //bias
cost=0;
for(int i=0;i<N;i++){
double xi=x_data[i],yi=y_data[i]; //第i个数据点
double error=yi-exp(ae*xi*xi+be*xi+ce);
Vector3d J; //雅可比矩阵(向量) //这个向量是3行 1列的,是个列向量
J[0]=-xi*xi*exp(ae*xi*xi+be*xi+ce); //de/da
J[1]=-xi*exp(ae*xi*xi+be*xi+ce); //de/db
J[2]=-exp(ae*xi*xi+be*xi+ce); //de/dc
H+=inv_sigma*inv_sigma*J*J.transpose();
b+=-inv_sigma*inv_sigma*error*J;
cost+=error*error;
}
//求解线性方程 Hx=b
Vector3d dx=H.ldlt().solve(b); //ldlt() ??
if(isnan(dx[0])){
cout <<"result is nan!"<<endl; //防止出现未定义的值
break;
}
if(iter>0 &&cost>=lastCost){
cout<<"cost: "<<cost<<">=last cost: "<<lastCost<<", break."<<endl;
break;
}
ae +=dx[0];
be +=dx[1];
ce +=dx[2];
lastCost=cost;
cout<<"total cost: "<<cost<<",\t \tupdate: "<<dx.transpose()<<
"\t\testimate params :" <<ae <<","<<be<<","<<ce<<","<<endl;
}
chrono::steady_clock::time_point t2=chrono::steady_clock::now();
chrono::duration<double> time_used=chrono::duration_cast<chrono::duration<double>>(t2-t1);
cout<<"solve time cost= "<<time_used.count()<<" seconds. "<<endl;
cout<<"estimated abc = "<<ae<<","<<be<<","<<ce<<endl;
return 0;
}
最终效果如下所示: