SVM分类算法
(1)svm定义及特性
定义:在机器学习领域, 支持向量机SVM(Support Vector Machine)是一个有监督的学习模型,通常用来进行模式识别、分类、以及回归分析。
一般特征:
1、SVM学习问题可以表示为凸优化问题,因此可以利用已知的有效算法发现目标函数的全局最小值。而其他分类方法(如基于规则的分类器和人工神经网络)都采用一种基于贪心学习的策略来搜索假设空间,这种方法一般只能获得局部最优解。
2、SVM通过最大化决策边界的边缘来控制模型的能力。尽管如此,用户必须提供其他参数,如使用核函数类型和引入松弛变量等。
3、通过对数据中每个分类属性引入一个哑变量,SVM可以应用于分类数据。
4、SVM一般只能用在二类问题,对于多类问题效果不好。
(2)假设有函数
,满足:
并且有:
从
出发,希望达到的目标就是让训练数据中y=1的特征
,y=0的特征
,则该预测对训练集分类很好。
接下来做个变形,将使用的结果标签y=0和y = 1替换为y = -1,y = 1,然后将
()中的
替换为b,最后将
替换为
。这样就有了
。所以又有
进一步,可以将假设函数做一个简化,将其简单映射到y=-1和y=1上。映射关系如下:
(3)假设现在有一个二维平面,平面上有两种不同的数据,分别用圈和叉表示。可以用一条直线将这两类数据分开,这条直线就相当于一个超平面,超平面一边的数据点所对应的y全是 -1 ,另一边所对应的y全是1。如下图所示:
为了使圈和叉的分类更好,我们就要使得圈和叉离分界线的距离越远越好,这里我们就要引入函数间隔和几何间隔的概念
(4)函数间隔和几何间隔
定义函数间隔(用
表示)为:
此时我们可以知道,当y=1时,为了获得较大的函数间隔,需令
取得最大值;
当y=-1时,为了获得较大的函数间隔,需要
取得较大的负值。
只要
>0,则预测正确,一个大的函数间隔表示一个很确定的正确预测。
于是我们将一个超平面玉珍个训练集合的函数间隔定义为:
但是,如果我们将w和b都增加为原来的两倍,由式子:
可得,间隔增大了两倍,这样的增大是无意义的,于是我们要添加正规划条件,用
代替(w,b),有:
,这就是几何间隔。
(5)这样我们就只要让几何间隔最大,就可以找到最好的间隔分类了。
即:
令γ =1,这时就只需
取得最大值,是函数间隔最大化。
但是由于
是非凸性的,不能直接用来求解最优化问题,于是我们将最优化问题变为
优化问题变成了一个二次目标函数,就可以用二次编程解决了。
(6)对偶化
由于这个问题的特殊结构,还可以通过拉格朗日对偶性(Lagrange Duality)变换到对偶变量 (dual variable) 的优化问题,即通过求解与原问题等价的对偶问题(dual problem)得到原始问题的最优解,这就是线性可分条件下支持向量机的对偶算法,通过给每一个约束条件加上一个拉格朗日乘子(Lagrange
multiplier)
,定义拉格朗日函数(通过拉格朗日函数将约束条件融合到目标函数里去,从而只用一个函数表达式便能清楚的表达出我们的问题):

令:

而当所有约束条件都满足时,则有
,亦即最初要最小化的量。当某个约束条件不满足时,例如
,那么显然有
,因此,在要求约束条件得到满足的情况下最小化
,实际上等价于直接最小化
(当然,这里也有约束条件,就是
≥0,i=1,…,n)
,目标函数变成了:

交换以后的新问题是原始问题的对偶问题,这个新问题的最优值用
来表示。而且有
≤
,在满足某些条件的情况下,这两者相等,这个时候就可以通过求解对偶问题来间接地求解原始问题。“
≤
在满足某些条件的情况下,两者等价”,这所谓的“满足某些条件”就是要满足KKT条件。
(7)KKT条件
KKT条件是解决最优化问题的时用到的一种方法。我们这里提到的最优化问题通常是指对于给定的某一函数,求其在指定作用域上的全局最小值。
设目标函数f(x),不等式约束为g(x),有的教程还会添加上等式约束条件h(x)。此时的约束优化问题描述如下:
则我们定义不等式约束下的拉格朗日函数L,则L表达式为:
其中f(x)是原目标函数,hj(x)是第j个等式约束条件,λj是对应的约束系数,gk是不等式约束,uk是对应的约束系数。
此时若要求解上述优化问题,必须满足下述条件(也是我们的求解条件):
这些求解条件就是KKT条件。(1)是对拉格朗日函数取极值时候带来的一个必要条件,(2)是拉格朗日系数约束(同等式情况),(3)是不等式约束情况,(4)是互补松弛条件,(5)、(6)是原约束条件。
(8)SVM分类算法Matlab代码:
clear;%清屏
clc;
X =textread ('D:\matlab.anzhuang\bin\date.txt');
n = length(X);%总样本数量
y = X(:,4);%类别标志
X = X(:,1:3);
TOL = 0.000001;%精度要求
C = 1;%参数,对损失函数的权重
b = 0;%初始设置截距b
Wold = 0;%未更新a时的W(a)
Wnew = 0;%更新a后的W(a)
for i = 1 : 50%设置类别标志为1或者-1
y(i) = -1;
end
a = zeros(n,1);%参数a
for i = 1 : n%随机初始化a,a属于[0,C]
a(i) = 0.2;
end
%为简化计算,减少重复计算进行的计算
K = ones(n,n);
for i = 1 :n%求出K矩阵,便于之后的计算
for j = 1 : n
K(i,j) = k(X(i,:),X(j,:));
end
end
sum = zeros(n,1);%中间变量,便于之后的计算,
for k = 1 : n
for i = 1 : n
sum(k) = sum(k) + a(i) * y(i) * K(i,k);
end
end
while 1%迭代过程
%启发式选点
n1 = 1;%初始化,n1,n2代表选择的2个点
n2 = 2;
%n1按照第一个违反KKT条件的点选择
while n1 <= n
if y(n1) * (sum(n1) + b) == 1 && a(n1) >= C && a(n1) <= 0
break;
end
if y(n1) * (sum(n1) + b) > 1 && a(n1) ~= 0
break;
end
if y(n1) * (sum(n1) + b) < 1 && a(n1) ~=C
break;
end
n1 = n1 + 1;
end
%n2按照最大化|E1-E2|的原则选取
E1 = 0;
E2 = 0;
maxDiff = 0;%假设的最大误差
E1 = sum(n1) + b - y(n1);%n1的误差
for i = 1 : n
tempSum = sum(i) + b - y(i);
if abs(E1 - tempSum)> maxDiff
maxDiff = abs(E1 - tempSum);
n2 = i;
E2 = tempSum;
end
end
%以下进行更新
a1old = a(n1);
a2old = a(n2);
KK = K(n1,n1) + K(n2,n2) - 2*K(n1,n2);
a2new = a2old + y(n2) *(E1 - E2) / KK;%计算新的a2
%a2必须满足约束条件
S = y(n1) * y(n2);
if S == -1
U = max(0,a2old - a1old);
V = min(C,C - a1old + a2old);
else
U = max(0,a1old + a2old - C);
V = min(C,a1old + a2old);
end
if a2new > V
a2new = V;
end
if a2new < U
a2new = U;
end
a1new = a1old + S * (a2old - a2new);%计算新的a1
a(n1) = a1new;%更新a
a(n2) = a2new;
%更新部分值
sum = zeros(n,1);
for k = 1 : n
for i = 1 : n
sum(k) = sum(k) + a(i) * y(i) * K(i,k);
end
end
Wold = Wnew;
Wnew = 0;%更新a后的W(a)
tempSum = 0;%临时变量
for i = 1 : n
for j = 1 : n
tempSum= tempSum + y(i )*y(j)*a(i)*a(j)*K(i,j);
end
Wnew= Wnew+ a(i);
end
Wnew= Wnew - 0.5 * tempSum;
%以下更新b:通过找到某一个支持向量来计算
support = 1;%支持向量坐标初始化
while abs(a(support))< 1e-4 && support <= n
support = support + 1;
end
b = 1 / y(support) - sum(support);
%判断停止条件
if abs(Wnew/ Wold - 1 ) <= TOL
break;
end
end
%输出结果:包括原分类,辨别函数计算结果,svm分类结果
for i = 1 : n
fprintf('第%d点:原标号 ',i);
if i <= 50
fprintf('-1');
else
fprintf(' 1');
end
fprintf(' 判别函数值%f 分类结果',sum(i) + b);
if abs(sum(i) + b - 1) < 0.5
fprintf('1\n');
else if abs(sum(i) + b + 1) < 0.5
fprintf('-1\n');
else
fprintf('归类错误\n');
end
end
end