感知机算法原理
实例问题
原始问题
对偶问题
原始问题及对偶问题的C++代码
#include <iostream>
using namespace std;
#define N 3 //需分类点的个数
#define M 2 //问题的维度
#define unita 1 //参数更新速度
double x[N][M]={{3.0,3.0},{4.0,3.0},{1.0,1.0}},y[N]={1,1,-1}; //三个点及对应的标签
double gram[N][N]={0}; //gram矩阵
int checkerror(double w[2],double b) //原始问题检查是否有误分类的点并返回对应下标
{
for(int i=0;i<N;i++)
if(!(y[i]*(w[0]*x[i][0]+w[1]*x[i][1]+b)>0))
return i;
return -1;
}
int checkerror1(double af[N],double b) //对偶问题检查是否有误分类的点并返回对应下标
{
double tmp;
for(int i=0;i<N;i++)
{
tmp=0;
for(int j=0;j<N;j++)
tmp+=af[j]*y[j]*gram[i][j];
if(!(y[i]*(tmp+b)>0))
return i;
}
return -1;
}
void duiouwenti() //对偶问题
{
double af[N]={0,0,0},b=0,w[N]={0};
int cnt=0;
for(int i=0;i<N;i++) //计算gram矩阵
for(int j=0;j<N;j++)
gram[i][j]=x[i][0]*x[j][0]+x[i][1]*x[j][1];
cout<<"gram矩阵:"<<endl;
for(int i=0;i<N;i++)
{
for(int j=0;j<N;j++)
cout<<gram[i][j]<<"\t";
cout<<"\n";
}
int ii;
printf("\nk\tx?\taf1\taf2\taf3\tb\n");
while((ii=checkerror1(af,b))!=-1)
{
af[ii]+=unita; //参数alpha,b更新
b+=y[ii]*unita;
//cout<<b<<endl;
cnt++;
printf("%d\tx(%d)\t%.2lf\t%.2f\t%.2f\t%.2f\n",cnt,ii+1,af[0],af[1],af[2],b);
}
b=0;
for (int k=0;k<N ;k++ ) //根据alpha及三个点计算w
{
w[0]+=af[k]*x[k][0]*y[k];
w[1]+=af[k]*x[k][1]*y[k];
b+=af[k]*y[k];
}
printf("\nf(x)=sign(%.2lf*x1%+-.2lf*x2%+-.2lf)\n",w[0],w[1],b); //输出最终分类器
}
int main(int argc, char *argv[])
{
double w[2]={0},b=0;
int i,cnt=0;
printf("k\t\t w\t\t b\n");
while((i=checkerror(w,b))!=-1)
{
w[0]=w[0]+unita*y[i]*x[i][0]; //原始问题参数w,b更新
w[1]=w[1]+unita*y[i]*x[i][1];
b=b+unita*y[i];
cnt++;
printf("%d\t\t(%.2lf,%.2lf)\t%.3lf\n",cnt,w[0],w[1],b);
}
printf("f(x)=sign(%.2lf*x1%+-.2lf*x2%+-.2lf)\n\n",w[0],w[1],b); //输出最终问题分类器
duiouwenti();
return 0;
}
运行结果
k w b
1 (3.00,3.00) 1.000
2 (2.00,2.00) 0.000
3 (1.00,1.00) -1.000
4 (0.00,0.00) -2.000
5 (3.00,3.00) -1.000
6 (2.00,2.00) -2.000
7 (1.00,1.00) -3.000
f(x)=sign(1.00*x1+1.00*x2-3.00)
gram矩阵:
18 21 6
21 25 7
6 7 2
k x? af1 af2 af3 b
1 x(1) 1.00 0.00 0.00 1.00
2 x(3) 1.00 0.00 1.00 0.00
3 x(3) 1.00 0.00 2.00 -1.00
4 x(3) 1.00 0.00 3.00 -2.00
5 x(1) 2.00 0.00 3.00 -1.00
6 x(3) 2.00 0.00 4.00 -2.00
7 x(3) 2.00 0.00 5.00 -3.00
f(x)=sign(1.00*x1+1.00*x2-3.00)