背景:项目需求,python框架只适合实现快速验证,但是算法真正部署项目中是不行的,需要将相关算法通过c++翻译并训练得到相应模型文件,并封装dll文件,本博客只实现训练和预测,dll文件详见参考文章。
前言:为保护客户数据,暂时使用鸢尾花数据集做测试;
数据集下载链接:
文件格式 iris_training.csv,iris_test.csv
链接:https://pan.baidu.com/s/1KzUwJwTgOYiy_tNUPZrCUQ
提取码:tojv
直接上代码
#include <iostream>
#include <fstream>
#include "opencv2/core/core.hpp"
#include "opencv2/ml/ml.hpp"
cv::Ptr <cv::ml::RTrees> model_load;
// 读取CSV文件并返回数据
float** readCSV(const char* filePath, int& rows, int& cols) {
std::ifstream file(filePath);
// 检查文件是否成功打开
if (!file.is_open()) {
std::cerr << "无法打开文件\n";
}
std::string line;
// 跳过第一行
getline(file, line);
// 统计行和列数
rows = 0;
cols = 0;
while (getline(file, line)) {
++rows;
std::istringstream iss(line);
std::string value;
while (getline(iss, value, ',')) {
++cols;
}
}
cols /= rows;
// 重新定位文件指针到文件开头
file.clear();
file.seekg(0, std::ios::beg);
// 跳过第一行
getline(file, line);
// 分配内存
float** data = new float* [rows];
for (int i = 0; i < rows; ++i) {
data[i] = new float[cols];
}
// 读取数据
for (int i = 0; i < rows; ++i) {
getline(file, line);
std::istringstream iss(line);
std::string value;
for (int j = 0; j < cols; ++j) {
getline(iss, value, ',');
data[i][j] = stof(value);
}
}
return data;
}
int train(float** data, int rows, int cols) {
float* data_arr = new float[rows * cols];
for (int i = 0; i < rows * cols; i++) {
data_arr[i] = data[i / cols][i % cols];
}
cv::Mat data_mat = cv::Mat(rows, cols, CV_32FC1, data_arr);
//获得标签
cv::Mat label = data_mat.col(cols - 1).clone();
//获得训练特征数据
data_mat = data_mat.colRange(0, cols - 1);
//std::cout << data_mat << "\n";
//std::cout << label << "\n";
//std::cout << data_mat.size() << "\n";
//std::cout << label.size() << "\n";
cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, label, cv::noArray(), cv::noArray(), cv::noArray(), cv::noArray());
cv::Ptr<cv::ml::RTrees> model = cv::ml::RTrees::create();
//树的最大可能深度
//model->setMaxDepth(100);
//节点最小样本数量
//model->setMinSampleCount(5);
//回归树的终止标准
//model->setRegressionAccuracy(0.01f);
//是否建立替代分裂点
//model->setUseSurrogates(false);
//最大聚类簇数
//model->setMaxCategories(15);
//先验类概率数组
//model->setPriors(cv::Mat());
//计算的变量重要性
//model->setCalculateVarImportance(true);
//树节点随机选择的特征子集的大小
//model->setActiveVarCount(1);
//训练模型
model->train(train_data);
//保存模型
model->save("test_model.xml");
printf("model saved success!\n");
delete[] data_arr;
return 0;
}
int init_model(const char* modelPath) {
model_load = cv::Algorithm::load<cv::ml::RTrees>(modelPath);
if (model_load.empty()) {
printf("load model failed!\n");
return -1;
}
return 0;
}
int predict(float** data, int rows, int cols) {
float* data_arr = new float[rows * cols];
for (int i = 0; i < rows * cols; i++) {
data_arr[i] = data[i / cols][i % cols];
}
cv::Mat data_mat = cv::Mat(rows, cols, CV_32FC1, data_arr);
//获得标签
cv::Mat label = data_mat.col(cols - 1).clone();
//获得训练特征数据
data_mat = data_mat.colRange(0, cols - 1);
//std::cout << data_mat << "\n";
//std::cout << label << "\n";
//std::cout << data_mat.size() << "\n";
//std::cout << label.size() << "\n";
cv::Ptr<cv::ml::TrainData> test_data = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, label, cv::noArray(), cv::noArray(), cv::noArray(), cv::noArray());
for (int i = 0; i < rows; i++) {
cv::Mat test_data = data_mat.row(i);
float out = model_load->predict(test_data);
std::cout << out << "\n";
//res[i] = out;
}
return 0;
}
int main()
{
const char* trainData = "iris_training.csv";
const char* testPath = "iris_test.csv";
// 读取csv文件
int rows, cols;
float** data = readCSV(trainData, rows, cols);
// 01 训练模型
train(data, rows, cols);
// 02 初始化
const char* modelPath = "test_model.xml";
init_model(modelPath);
// 04 加载测试集
float** testData = readCSV(testPath, rows, cols);
// 05 预测
predict(testData, rows, cols);
// 释放每行的内存
for (int i = 0; i < rows; ++i) {
delete[] data[i];
}
// 释放指向每行的指针的内存
delete[] data;
for (int i = 0; i < rows; ++i) {
delete[] testData[i];
}
delete[] testData;
return 0;
}
关于封装dll文件,参考
vs2022环境下,使用c#调用c++生成的dll动态链接库,实现ocr和条形码的识别_vs2022 c# c++-优快云博客