【SVM理论到实践2】OpenCv中的支持向量机SVM源代码的解读

本文详细介绍OpenCV中支持向量机(SVM)的应用流程,包括训练样本准备、参数设置、训练过程及预测等关键步骤,并解析SVM相关类与函数的使用方法。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

/***************************************************************************************************** 
程序功能: 
        OpenCv中的支持向量机SVM源代码的解读
使用步骤:
        无论是使用OpenCv中的SVM还是使用其他库中的SVM进行分类,一般的步骤分为以下几步进行:
		     1)准备训练样本----正样本和负样本(图片数据和标签数据)
			 2)设置SVM相应的训练参数
			 3)调用训练函数开始训练SVM相应的分类器
			 4)利用训练好的SVM分类器对没有见过的图片进行预测
			 5)显示预测结果
	    OpenCv中SVM的代码流程:
			 1)获得训练样本并且制作其相应的标签(trainingDataMat,labelsMat)
			 2)设置SVM的训练学习参数CvSVMParams
			 3)调用CvSVM的训练函数train()进行训练
			 4)等待SVM相应的分类器训练成功之后,调用CvSVM的预测函数predict()对新输入的样本进行预测,并且输出
				结果类型,输入样本对应的类别
			 5)获取支持向量(CvSVM::get_support_vector_count,CvSVM::get_support_vector )
开发环境: 
       VS2012 + OpenGl(GLUT3.7) + OpenCv2.4.9 + Halcon10.0 
时间地点: 
       陕西师范大学 文津楼----2017.3.2
作    者: 
       九月 
*****************************************************************************************************/  
#include "stdafx.h"  
#include <iostream>  
#include <gl/glut.h>                                                
#include <opencv.hpp>
using namespace std;  
using namespace  cv;

/***************************************************************************************************** 
1)支持向量机参数类CvSVMParams源代码的解读:
*****************************************************************************************************/  
struct CV_EXPORTS_W_MAP CvSVMParams
{
    CvSVMParams();
    CvSVMParams( int           svm_type,          //[1]支持向量机SVM的类型(常用的SVM类型有五种)
		         int           kernel_type,      //[2]支持向量机和函数的类型(常用的和函数类型有四种)
                 double        degree,           //[3]多项式POLY核函数的参数degree
				 double        gamma,            //[4]POLY/RBF/SIGMOID这三个核函数的参数gamma
				 double        coef0,            //[5]POLY/ SIGMOID核函数的参数coef0
                 double        Cvalue,           //[6]SVM类型(C_SVC/ EPS_SVR/ NU_SVR)的参数C
				 double        nu,               //[7]SVM类型(NU_SVC/ ONE_CLASS/ NU_SVR)的参数 
				 double        p,                //[8]SVM类型(EPS_SVR)的参数
                 CvMat*         class_weights,    //[9]C_SVC中的可选权重,赋给指定的类,乘以C今后变成class_weights*C。
				                                  //错误分类处罚项。权重越大,某一类误分类数据的处罚项就越大。
				 CvTermCriteria term_crit );      //[10]SVM的迭代练习过程的中断前提

    CV_PROP_RW int         svm_type;
    CV_PROP_RW int         kernel_type;
    CV_PROP_RW double      degree;               // for poly
    CV_PROP_RW double      gamma;                // for poly/rbf/sigmoid
    CV_PROP_RW double      coef0;                // for poly/sigmoid

    CV_PROP_RW double      C;                    // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
    CV_PROP_RW double      nu;                   // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
    CV_PROP_RW double      p;                    // for CV_SVM_EPS_SVR
    CvMat*      class_weights;                    // for CV_SVM_C_SVC
    CV_PROP_RW CvTermCriteria term_crit;          // termination criteria
};
/***************************************************************************************************** 
2)下面是CvSVM这个类的源代码解读:
*****************************************************************************************************/  
class CV_EXPORTS_W CvSVM : public CvStatModel
{
public:
    // 【1】SVM的类型
    enum
	{ 
		C_SVC     =100,                    //[1]C_SVC是C类支持向量机。主要用于N分类问题,允许用错误代价
		                                   //   系数C进行不完全分类(关于C--有些人教异常值惩罚因子,但是
										   //   根据我的理解,我更喜欢将它理解为错误代价系数)错误代价系
										   //   数C越大,那么分割超平面的最大化分类间隔就会越窄,这样造
										   //   成的问题是虽然训练过程错误很小,但是预测可能会出现较多的
										   //   预测错误
		NU_SVC    =101,                    //[2]NU类支持向量机,主要用于N分类问题,n类似然不完全分类的分
		                                   //   类器,参数为nu取代C(其取值空间在0~1之间),nu越大,决策的
										   //   边界越平滑
		ONE_CLASS =102,                    //[3]单分类器,所有的训练数据取自同一个类里,然后SVM建立一个分
		                                   //   界线以分割该类在特征空间中所占区域和其他类在特征空间中的所
										   //   占区域
		EPS_SVR   =103,                    //[4]epsilon类支持向量机,训练集中的特征向量和拟合出来的超平面
		                                   //   的距离要小于p,错误代价系数C被采用
		NU_SVR    =104                     //[5]
	};

    //【2】SVM中常见的核函数类型
    enum 
	{
		LINEAR  = 0,                       //[1]线性核函数 
		POLY    = 1,                       //[2]多项式核函数
		RBF     = 2,                       //[3]径向基核函数
		SIGMOID = 3                        //[4]SIGMOD核函数
	};

    //【3】SVM的参数类型
    enum 
	{
		C     = 0, 
		GAMMA = 1,
		P     = 2,
		NU    = 3,
		COEF  = 4,
		DEGREE= 5
	};

    CV_WRAP CvSVM();                                          //[1]SVM的构造函数
    virtual ~CvSVM();                                        //[2]析构函数

    CvSVM( const CvMat* trainData,                           //[3]带参数的构造函数
		   const CvMat* responses,
           const CvMat* varIdx=0, 
		   const CvMat* sampleIdx=0,
           CvSVMParams params=CvSVMParams() );

    virtual bool train( const CvMat* trainData,             //[4]SVM的训练函数
		                 const CvMat* responses,
                         const CvMat* varIdx=0, 
						 const CvMat* sampleIdx=0,
                        CvSVMParams params=CvSVMParams() );
	                                                          //[5]自动训练函数
    virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
        int kfold = 10,
        CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
        CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
        CvParamGrid pGrid      = get_default_grid(CvSVM::P),
        CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
        CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
        CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
        bool balanced=false );
	                                                          //[6]SVM的预测函数
    virtual float predict( 
		                     const CvMat* sample,            
		                     bool returnDFVal=false 
						   ) const;
    virtual float predict( 
		                     const CvMat* samples, 
		                     CV_OUT CvMat* results 
						   ) const;

    CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
          const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
          CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
                            int k_fold = 10,
                            CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
                            CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
                            CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
                            CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
                            CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
                            CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
                            bool balanced=false);
    CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
    CV_WRAP_AS(predict_all) void predict( cv::InputArray samples, cv::OutputArray results ) const;

    CV_WRAP virtual int get_support_vector_count() const;      //[7]得到支持向量点的个数
    virtual const float* get_support_vector(int i) const;     //[8]得到支持向量
    virtual CvSVMParams get_params() const { return params; };
    CV_WRAP virtual void clear();

    static CvParamGrid get_default_grid( int param_id );

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }

protected:

    virtual bool set_params( const CvSVMParams& params );
    virtual bool train1( int sample_count, int var_count, const float** samples,
                    const void* responses, double Cp, double Cn,
                    CvMemStorage* _storage, double* alpha, double& rho );
    virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
                    const CvMat* responses, CvMemStorage* _storage, double* alpha );
    virtual void create_kernel();
    virtual void create_solver();

    virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    void optimize_linear_svm();

    CvSVMParams params;
    CvMat* class_labels;
    int var_all;
    float** sv;
    int sv_total;
    CvMat* var_idx;
    CvMat* class_weights;
    CvSVMDecisionFunc* decision_func;
    CvMemStorage* storage;

    CvSVMSolver* solver;
    CvSVMKernel* kernel;

private:
    CvSVM(const CvSVM&);
    CvSVM& operator = (const CvSVM&);
};
/***************************************************************************************************** 
3)OpenCv中SVM训练函数train()源代码的详解:
4)函数参数的详解:
       1) const cv::Mat& trainData---训练样本的数据类型必须是CV_32FC1(32位浮点型单通道),数据必须是以
	              数据必须是CV_ROW_SAMPLE的,也就是说特征点向量(训练数据)必须是以行来存储的,比如说一张
				  图片代表一个训练样本的话,那么,我们先要把这个图片拉成一个行向量;这个矩阵中的一行就
				  代表一个训练样本,1000行,就代表1000张图片,1000个训练样本
	   2)const cv::Mat& responses---标签数据,也必须为CV_32FC1类型,一般是一个列向量,列向量的每一行代
	              表一个图片的类别标签
	   3)const cv::Mat& varIdx    = cv::Mat()
	   4)const cv::Mat& sampleIdx = cv::Mat()
	   5)CvSVMParams params        = CvSVMParams()---在SVM参数设置步骤设置的参数
*****************************************************************************************************/
 CV_WRAP virtual bool train( 
	                          const cv::Mat& trainData,             
							  const cv::Mat& responses,
                              const cv::Mat& varIdx    = cv::Mat(),
							  const cv::Mat& sampleIdx = cv::Mat(),
                              CvSVMParams params        = CvSVMParams() 
							 );
 /***************************************************************************************************** 
5)OpenCv中的预测函数predict()函数详解,predict()函数的原型如下所示:
6)参数详解:
            1---const cv::Mat& sample-------待分类或者带预测的图片
			2---bool returnDFVal=false------指定返回值的类型,如果为flase,则返回类别的编号,若为ture,
			                                则说明是一个二分类问题
			3---函数的返回值----------------这个函数用来预测一个新样本的类别响应;在分类问题中,这个函
			          数返回类别的编号;在回归问题中,返回函数值;输入的预测样本必须和trainData的样本
					  大小一致,否则会出现报错。若是练习中应用了varIdx参数,必然记住在predict函数中应用
					  跟练习特点一致的特点
*****************************************************************************************************/
 CV_WRAP virtual float predict( 
	                            const cv::Mat& sample,
								bool returnDFVal=false 
								) const;
CV_WRAP_AS(predict_all) void predict(
	                            cv::InputArray samples,
								cv::OutputArray results 
								) const;
/***************************************************************************************************** 
7)OpenCv中SVM分类问题的代码流程:
      1)获得训练样本并且制作其相应的标签(trainingDataMat,labelsMat)
	  2)设置SVM的训练学习参数CvSVMParams
	  3)调用CvSVM的训练函数train()进行训练
	  4)等待SVM相应的分类器训练成功之后,调用CvSVM的预测函数predict()对新输入的样本进行预测,并且输出
	     结果类型,输入样本对应的类别
	  5)获取支持向量(CvSVM::get_support_vector_count,CvSVM::get_support_vector )
*****************************************************************************************************/
/***************************************************************************************************** 
8)OpenCv中SVM分类器如何进行多问题的分类,主要的方法有三种,如下所示:
      1)一对多的最大响应策略(one against all)
	  2)一对一的投票策略
	  3)一对一的淘汰策略
这三种具体的扩展策略请参考(数字图像处理与机器视觉---张铮)这本书的支持向量机这一章
*****************************************************************************************************/

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值