SVM(Support Vector Machine,支持向量机)是一种经典的分类算法,应用广泛:它利用正样本和负样本进行训练,进而预测新样本属于正类还是负类。线性SVM实际上是在寻找一个超平面,使该超平面到最近的正、负样本的距离(即间隔,margin)最大,从而实现最优的二分类。这里介绍OpenCV自带的例子introduction_to_svm.cpp,本人对其稍作修改,完整代码如下:
#include <cstdio>
#include <exception>
#include <iostream>
#include <vector>

#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include "opencv2/imgcodecs.hpp"
#include <opencv2/highgui.hpp>
#include <opencv2/ml.hpp>
#include <opencv2/opencv.hpp>
using namespace cv;
using namespace cv::ml;
using namespace std;
int main(int, char**)
{
// Data for visual representation
int width = 512, height = 512;
Mat image = Mat::zeros(height, width, CV_8UC3);
// Set up training data
//! [setup1]
int labels[4] = {1, 1, -1, -1};
float trainingData[4][2] = { {501, 10}, {255, 80}, {501, 255}, {10, 501} };
//! [setup1]
//! [setup2]
Mat trainingDataMat(4, 2, CV_32FC1, trainingData);
Mat labelsMat(4, 1, CV_32SC1, labels);
//! [setup2]
// Train the SVM
//! [init]
Ptr<SVM> svm = SVM::create();
svm->setType(SVM::C_SVC);
svm->setKernel(SVM::LINEAR);
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));
//! [init]
//! [train]
svm->train(trainingDataMat, ROW_SAMPLE, labelsMat);
//! [train]
// Show the decision regions given by the SVM
//! [show]
Vec3b green(0,255,0), blue (255,0,0);
for (int i = 0; i < image.rows; ++i)
for (int j = 0; j < image.cols; ++j)
{
Mat sampleMat = (Mat_<float>(1,2) << j,i);
float response = svm->predict(sampleMat);
if (response == 1)
image.at<Vec3b>(i,j) = green;
else if (response == -1)
image.at<Vec3b>(i,j) = blue;
}
//! [show]
// Show the training data
//! [show_data]
int thickness = -1;
int lineType = 8;
for (int i = 0; i < sizeof(trainingData)/sizeof(trainingData[0]); i++)
{
circle(
image, Point(trainingData[i][0], trainingData[i][1]), 5, labels[i] == 1 ? Scalar(255, 255, 255) : Scalar(0, 0, 0), thickness, lineType );
}
//! [show_data]
// Show support vectors
//! [show_vectors]
thickness = 2;
lineType = 8;
Mat sv = svm->getSupportVectors();
for (int i = 0; i < sv.rows; ++i)
{
const float* v = sv.ptr<float>(i);
cout << "[" << v[0] << ", " << v[1] << "]'" << endl;
}
//! [show_vectors]
vector<int> compression_params;
compression_params.push_back(IMWRITE_PNG_COMPRESSION);
compression_params.push_back(9);
try {
imwrite("result.png", image, compression_params); // save the image
}
catch (runtime_error& ex) {
fprintf(stderr, "Exception converting image to PNG format: %s\n", ex.what());
return 1;
}
imshow("SVM Simple Example", image); // show it to the user
waitKey(0);
}
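运行后,程序输出的支持向量(support vector)为:[0.00232288, 0.00816326]'(注意:对于线性核,OpenCV的getSupportVectors()返回的是压缩后的单个向量,即决策超平面的法向量,而不是原始训练样本,所以它与任何训练点都不重合;如需原始支持向量,可调用getUncompressedSupportVectors()),
并得到下图: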