【机器学习】02. vs平台c++随机森林实现回归训练和预测并保存xml模型

背景:项目需求,python框架只适合实现快速验证,但是算法真正部署项目中是不行的,需要将相关算法通过c++翻译并训练得到相应模型文件,并封装dll文件,本博客只实现训练和预测,dll文件详见参考文章。

前言:为保护客户数据,暂时使用鸢尾花数据集做测试;

数据集下载链接:

文件格式 iris_training.csv,iris_test.csv

链接:https://pan.baidu.com/s/1KzUwJwTgOYiy_tNUPZrCUQ 
提取码:tojv

 直接上代码

#include <iostream>
#include <fstream>
#include "opencv2/core/core.hpp"
#include "opencv2/ml/ml.hpp"
 
 
cv::Ptr <cv::ml::RTrees> model_load;
 
 
// 读取CSV文件并返回数据
float** readCSV(const char* filePath, int& rows, int& cols) {
    std::ifstream file(filePath);
    // 检查文件是否成功打开
    if (!file.is_open()) {
        std::cerr << "无法打开文件\n";
    }
    
    std::string line;
    // 跳过第一行
    getline(file, line);
 
    // 统计行和列数
    rows = 0;
    cols = 0;
    while (getline(file, line)) {
        ++rows;
        std::istringstream iss(line);
        std::string value;
        while (getline(iss, value, ',')) {
            ++cols;
        }
    }
    cols /= rows;
 
    // 重新定位文件指针到文件开头
    file.clear();
    file.seekg(0, std::ios::beg);
 
    // 跳过第一行
    getline(file, line);
 
    // 分配内存
    float** data = new float* [rows];
    for (int i = 0; i < rows; ++i) {
        data[i] = new float[cols];
    }
 
    // 读取数据
    for (int i = 0; i < rows; ++i) {
        getline(file, line);
        std::istringstream iss(line);
        std::string value;
        for (int j = 0; j < cols; ++j) {
            getline(iss, value, ',');
            data[i][j] = stof(value);
        }
    }
 
 
    return data;
}
 
 
int train(float** data, int rows, int cols) {
 
    float* data_arr = new float[rows * cols];
 
    for (int i = 0; i < rows * cols; i++) {
        data_arr[i] = data[i / cols][i % cols];
    }
 
    cv::Mat data_mat = cv::Mat(rows, cols, CV_32FC1, data_arr);
    
    //获得标签
    cv::Mat label = data_mat.col(cols - 1).clone();
 
    //获得训练特征数据
    data_mat = data_mat.colRange(0, cols - 1);
 
    //std::cout << data_mat << "\n";
    //std::cout << label << "\n";
 
    //std::cout << data_mat.size() << "\n";
    //std::cout << label.size() << "\n";
 
    cv::Ptr<cv::ml::TrainData> train_data = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, label, cv::noArray(), cv::noArray(), cv::noArray(), cv::noArray());
    cv::Ptr<cv::ml::RTrees> model = cv::ml::RTrees::create();
 
    //树的最大可能深度
    //model->setMaxDepth(100);
    //节点最小样本数量
    //model->setMinSampleCount(5);
    //回归树的终止标准
    //model->setRegressionAccuracy(0.01f);
    //是否建立替代分裂点
    //model->setUseSurrogates(false);
    //最大聚类簇数
    //model->setMaxCategories(15);
    //先验类概率数组
    //model->setPriors(cv::Mat());
    //计算的变量重要性
    //model->setCalculateVarImportance(true);
    //树节点随机选择的特征子集的大小
    //model->setActiveVarCount(1);
 
    //训练模型
    model->train(train_data);
 
    //保存模型
    model->save("test_model.xml");
    printf("model saved success!\n");
 
    delete[] data_arr;
 
 
    return 0;
}
 
 
int init_model(const char* modelPath) {
    model_load = cv::Algorithm::load<cv::ml::RTrees>(modelPath);
    if (model_load.empty()) {
        printf("load model failed!\n");
        return -1;
    }
    
    return 0;
}
 
 
int predict(float** data, int rows, int cols) {
 
    float* data_arr = new float[rows * cols];
 
    for (int i = 0; i < rows * cols; i++) {
        data_arr[i] = data[i / cols][i % cols];
    }
 
    cv::Mat data_mat = cv::Mat(rows, cols, CV_32FC1, data_arr);
 
    //获得标签
    cv::Mat label = data_mat.col(cols - 1).clone();
 
    //获得训练特征数据
    data_mat = data_mat.colRange(0, cols - 1);
 
    //std::cout << data_mat << "\n";
    //std::cout << label << "\n";
 
    //std::cout << data_mat.size() << "\n";
    //std::cout << label.size() << "\n";
 
    cv::Ptr<cv::ml::TrainData> test_data = cv::ml::TrainData::create(data_mat, cv::ml::ROW_SAMPLE, label, cv::noArray(), cv::noArray(), cv::noArray(), cv::noArray());
 
 
    for (int i = 0; i < rows; i++) {
        cv::Mat test_data = data_mat.row(i);
        float out = model_load->predict(test_data);
        std::cout << out << "\n";
        //res[i] = out;
    }
 
 
    return 0;
}
 
 
int main()
{
    const char* trainData = "iris_training.csv";
    const char* testPath = "iris_test.csv";
    
    // 读取csv文件
    int rows, cols;
    float** data = readCSV(trainData, rows, cols);
 
    // 01 训练模型
    train(data, rows, cols);
 
    // 02 初始化
    const char* modelPath = "test_model.xml";
    init_model(modelPath);
 
    // 04 加载测试集
    float** testData = readCSV(testPath, rows, cols);
 
    // 05 预测
    predict(testData, rows, cols);
 
 
    // 释放每行的内存
    for (int i = 0; i < rows; ++i) {
        delete[] data[i];
    }
    // 释放指向每行的指针的内存
    delete[] data;
 
 
    for (int i = 0; i < rows; ++i) {
        delete[] testData[i];
    }
    delete[] testData;
 
 
    return 0;
}

关于封装dll文件,参考

vs2022环境下,使用c#调用c++生成的dll动态链接库,实现ocr和条形码的识别_vs2022 c# c++-优快云博客

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值