使用 SVM 训练一个判断是否戴眼镜的分类模型,感觉效果还不错。
具体实现的代码如下:
#include <cstdio>
#include <iostream>
#include <string>

#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include "opencv2/imgcodecs.hpp"
#include <opencv2/highgui.hpp>
#include <opencv2/ml.hpp>

using namespace cv;
using namespace cv::ml;
using namespace std;
int main(int argc ,char ** argv) { ///////////////// string filename = argv[1];////the name of cvs file string testFileName = argv[2];
//--------------------- 1. Set up training data randomly --------------------------------------- /////////get trainData and labels from cvs file Ptr<TrainData> Samples; Samples = TrainData::loadFromCSV(filename.c_str(), 0); Mat trainDataGet = Samples->getTrainSamples(); Mat labelsGet = Samples->getTrainResponses(); Mat trainData(trainDataGet.rows, trainDataGet.cols, CV_32FC1); Mat labels(labelsGet.rows, labelsGet.cols, CV_32SC1);
trainDataGet.convertTo(trainData, CV_32FC1); labelsGet.convertTo(labels, CV_32SC1); printf("\n%d,%d,%d\n", labels.rows, labels.cols, labels.channels()); printf("\n%d,%d,%d\n", trainData.rows, trainData.cols, trainData.channels());
// cout<<labelsGet<<endl; // cout<<labels<<endl;
//------------------------ 2. Set up the support vector machines parameters -------------------- //------------------------ 3. Train the svm ---------------------------------------------------- cout << "Starting training process" << endl; //! [init] Ptr<SVM> svm = SVM::create(); svm->setType(SVM::C_SVC); svm->setGamma(3); svm->setC(0.1); svm->setKernel(SVM::LINEAR); svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, (int)1e7, 1e-6)); //! [init] //! [train] svm->train(trainData, ROW_SAMPLE, labels); //! [train] cout << "Finished training process" << endl;
//--------------------- 4. Set up predict data randomly --------------------------------------- /////////get trainData and labels from cvs file Ptr<TrainData> SamplesTest; SamplesTest = TrainData::loadFromCSV(testFileName.c_str(), 0); Mat TestDataGet = SamplesTest->getTrainSamples(); Mat TestData(TestDataGet.rows, TestDataGet.cols, CV_32FC1);
TestDataGet.convertTo(TestData, CV_32FC1); printf("\n%d,%d,%d\n", TestData.rows, TestData.cols, TestData.channels());
//------------------------ 5. predict and print results ---------------------------------------- for(int i = 0; i < TestData.rows; i++) { float response = svm->predict(TestData.row(i));
printf("*****%d********%f \n", i, response); } //------------------------ 6. save model ---------------------------------------- FileStorage modelFile("./svm.yml", FileStorage::WRITE); svm->write(modelFile);
waitKey(0); }
|
其中,第一个入参是训练集 CSV 文件的文件名,第二个入参是测试集 CSV 文件的文件名(两个文件的最后一列均为标签)。