MATLAB实现SVM多分类(one-vs-rest),利用自带函数fitcsvm
SVM多分类
SVM也叫支持向量机,其是一个二类分类器,但是对于多分类,SVM也可以实现。主要方法就是训练多个二类分类器。常见的有以下两种方式:
一对一(one-vs-one)
给定m个类,对m个类中的每两个类都训练一个分类器,总共的二类分类器个数为 m(m-1)/2 .比如有三个类,1,2,3,那么需要有三个分类器,分别是针对:1和2类,1和3类,2和3类。对于一个需要分类的数据x,它需要经过所有分类器的预测,最后使用投票的方式来决定x最终的类属性。
一对多(one-vs-rest)
给定m个类,需要训练m个二类分类器。其中的分类器 i 是将 i 类数据设置为类1(正类),其它所有m-1个i类以外的类共同设置为类2(负类),这样,针对每一个类都需要训练一个二类分类器,最后,我们一共有 m 个分类器。对于一个需要分类的数据 x,通常选择置信度最大的类别标记为分类结果。
fitcsvm简单介绍
在新版本中svmtrain和svmclassify函数提示已经被移除,所以我们应该跟上潮流学习使用fitcsvm。
// An highlighted block
SVMModel = fitcsvm(X,Y,'ClassNames',{
'negClass','posClass'},'Standardize',true,...
'KernelFunction','rbf','BoxConstraint',1);
简单说一下参数:
X是训练样本,nxm的矩阵,n是样本数,m是特征维数;
Y是样本标签,nx1的矩阵,n是样本数;
‘ClassNames’,{‘negClass’,‘posClass’} 为键值对参数,指定正负类别,负类名在前,正类名在后,与样本标签Y中的元素对应;
‘Standardize’,true 为键值对参数,指示软件是否应在训练分类器之前使预测期标准化!
‘KernelFunction’,‘rbf’ 为键值对参数,有3种 ‘linear’(默认), ‘gaussian’ (or ‘rbf’), ‘polynomial’
‘BoxConstraint’,1 为键值对参数,直观上可以理解为一个惩罚因子(或者说正则参数),这个参数和svmtrain里的-c是一个道理。其实际上涉及到软间隔SVM的间隔(Margin)大小。
基本思想如下:当原始数据未能呈现出较好的可分性时,算法允许其在训练集上呈现出一些误分类,matlab默认的BoxConstraint为1。框约束的数值越大,意味着惩罚力度越小,最后得到的分类超平面的间隔越小,支持向量数越多,模型越复杂。这也就是很多机器学习理论书中一开始推导的硬间隔支持向量机(Hard-Margin SVM)。因为该参数默认为1,所以使用默认参数训练时,我们采用的是软间隔SVM。
更详细的大家可以参考官方说明文档 [https://ww2.mathworks.cn/help/stats/fitcsvm.html].
代码
说一下思路:
1.我自己造的数据不用太关心,训练数据是60x2,60是样本数,2是特征数;测试数据是20x2的。
2.目标是分5类,一对多的方式,就要分别训练5个SVM模型;每个模型都是一个二分类,所以需要正、负样本的划分。我是这么做的正样本全部来自该类别,负样本从其它4个类别中随机选择,但数目与正样本相同。有了每一类的正、负样本,这就得到了训练样本X;再设定标签,我设的是+1,-1,这就得到了样本标签Y;其它参数均默认不设,这样就可以为每一类样本训练SVM模型了。
3.测试样本并不需要对每一类划分正、负样本,只要知道测试数据和样本标签即可。
4.每个测试样本在5个SVM模型中均得到一个得分score,利用最大得分判定该样本最终属于哪一类。
5.这个混淆矩阵函数confusionmat是真的好用,只需要知道真实标签和预测标签就能算出查准率(precision)、查全率(recall)和综合评价指标(F-measure)。
如图:
类别1的查准率 = a / ( a + d + g ) =a/(a+d+g) =a/(a+d+g)
类别1的查全率 = a / ( a + b + c ) =a/(a+b+c) =a/(a+b+c)
类别2的查准率 = e / ( b + e + h ) =e/(b+e+h) =e/(b+e+h)
类别2的查全率 = e / ( d + e + f ) =e/(d+e+f) =e/(d+e+f)
···
// An highlighted block
clc;
clear;
close all;
tic
fprintf('-----已开始请等待-----\n\n');
%% 造数据不用关心,直接跳过
% 造数据 20*2
data = [0.4,0.3;-0.5,0.1;-0.2,-0.3;0.5,-0.3;
2.1,1.9;1.8,2.2;1.7,2.5;2.3,1.6;
-2.2,1.6;-1.9,2.1;-1.7,2.6;-2.3,2.5;
-3.1,-1.9;-2.8,-2.1;-1.9,-2.5;-2.3,-3.2;
3.9,-3.5;2.8,-2.2;1.7,-3.1;2.5,-3.4];
data1 = data + 2.5*rand(20,2);
data2 = data + 2.5*rand(20,2);
data3 = data + 2.5*rand(20,2); data1(17:20,:);
% 训练数据
train_data = [data1(1:4,