XGBoost解析系列--源码主流程

本文详细解析了XGBoost的源码主流程,从入口过程到训练过程,深入探讨了Train主框架、UpdateOneIter流程,包括LazyInitDMatrix、PredictRaw、obj_->GetGradient和gbm_->DoBoost等步骤。文中还介绍了数据加载、模型初始化、特征列迭代器构建、预测与梯度计算等关键环节,为读者提供了深入理解XGBoost内部实现的线索。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >


0.前言

  本文介绍XGBoost的源代码流程,先梳理源码主干流程,方便读者理解,结合函数名进行说明具体逻辑与功能。如果读者对底层实现感兴趣,后续会有细节具体专题,默认读者已读过《 XGBoost解析系列-准备》《XGBoost解析系列-原理》。未来还会有分布式实现内容,这部分才是hard work。本文暂时先介绍单机版本实现,这样才更容易理解复杂的分布式版本。

1.入口过程

  重要说明:本文笔者基于xgboost git库的log commit:d9d5293cdbbbf67dc8ff9d4a3f171d0990fdd1ee (2017.10.26 17:31:10提交的commit),以demo文件夹binary_classification案例为例作为说明例子,对应的配置文件为:mushroom.conf,默认读者阅读过《 XGBoost解析系列-准备》,并将相应的配置修改正确,否则下面运行会出错。进入xgboost主目录,运行命令:

./xgboost demo/binary_classification/mushroom.conf

  该语句会运行整个xgboost框架,程序入口函数为cli_main.ccmain函数,调用xgboost::CLIRunTask(argc, argv); 主过程很简单,依次执行:

  1. 判断参数合法性,不合法直接退出;
  2. 调用rabit::Init初始化整个框架的分布特性,rabit是分布式通信库。
  3. 将配置文件通过kv形式读入vector<pair<string, string> >cfg变量,实现基于common::ConfigIterator继承于ConfigStreamReader,上层父类ConfigReaderBase,实现核心接口Next()函数,调用GetNextToken解析每行的token,分别解析出参数namevalue
  4. 使用sscanf读入命令行参数到cfg变量。
  5. 将cfg变量初始化CLIParam对象,继承于参数类模板dmlc::Parameter,后续会有专题介绍模板宏定义参数。
  6. 根据task类型执行核心流程:训练train、dump模型DumpModel、预测Predict。其中主目录下demo/binary_classification/runexp.sh脚本有完整例子,默认为训练train。
  7. 程序退出时调用rabit::Finalize 释放资源

2.Train过程

2.1 Train主框架

  xgboost最核心部分当属Train过程,训练进入核心函数CLITrain(const CLIParam& param),根据上层实例化param进行训练控制。主要执行以下:

  1. rabit::IsDistributed()判断是否为分布式模式,若是,则打印相关log信息。
  2. DMatrix::Load()根据数据文件路径,支持本地路径与分布式的hadoop文件路径,生成统一URI路径格式;解析底层数据格式,支持libsvm、libffm。解析生成数据源对象,再加载到内存数据CSR格式对象,即DMatrix对象,专题详细介绍查看此处。加载训练数据与预测数据到dtraindeval(支持多份eval数据),并合并到cache_matscache_mats用于构建全局特征统计信息,合并是因为deval中的数据可能超过dtrain的边界,或者dtrain不存在。接着将dtrain加入deval,因为训练集也需要评估。
  3. 使用cache_mats实例化learner,基于C++工厂设计模式,使用Learner::Create()静态方法实例化出具体实现类LearnerImplLearner 继承于rabit::Serializable,分布式下具备相互通信功能。
  4. 检查当前版本号rabit::LoadCheckPoint(); 若为0,为初始状态,判断param.model_in是否存在dump模型路径,存在则使用dump模型来load()初始化learner对象,通过Configure()初始化模型。否则,通过Configure()初始化learner,再调用InitModel()初始化内部模型。两者都需要调用Configure(),根据配置信息初始化成员,过程如下:

    1)构建map<string, string> cfg_,设置objective、booster、updater、predictor等配置项
    2)调用InitAllowUnknown()构建训练参数对象LearnerTrainParam
    3)不存在模型会调用InitAllowUnknown()构建模型参数对象LearnerModelParam mparam
    4) 前者调用,后者不调用:前者load模型直接实例化boost框架组件实例与目标函数组件实例,通过指针gbm_、obj_调用初始化,而后者在InitModel()才生成实例,因此会跳过。

  InitModel()采用懒惰方式初始化,LazyInitModel()过程如下:

    1) 遍历cache_mats数据得到本地特征数最大值,调用rabit::Allreduce<rabit::op::Max>会上报本地特征数最大值,通过Max算子计算全局最大值,利用Allreduce同步所有主机。
    2) 同样基于C++工厂设计模式,使用类静态函数Create()实例化损失函数框架组件与boost框架组件实例,后续会有专题说明。本文mushroom.conf配置中目标函数与boost框架为:
      a) objective=binary:logisticObjFunction基类使用静态函数Create()实例化损失函数框架组件对象RegLossObj<LogisticClassification>()LogisticClassification提供:i)PredTransform生成base_score作为boost初始值;ii)一阶梯度计算FirstOrderGradient与二阶梯度计算SecondOrderGradientObjFunction还实现抽象方法GetGradient()调用具体类的梯度计算方法。
      b)booster=gbtreeGradientBooster基类使用静态函数Create()实例化boost框架组件对象GBTree(),继承于抽象类GradientBooster。i)调用Configure()方法来初始化内部成员:构建GBTreeModel核心成员model_,该成员封装回归树集合vector<unique_ptr<RegTree>> trees; ii)清空updaters列表,后续会构建;iii)初始化预测器predictor:基于工厂设计模式,Predictor基类调用静态函数Create(),根据配置参数”cpu_predictor”生成对应CPUPredictor对象。

  5. 根据参数num_round执行迭代:UpdateOneIter执行单次迭代更新,EvalOneIter每次迭代后对预测数据集进行预测。迭代结束前会判断是否保存模型,最后检查rabit 同步version号,rabit::CheckPoint同步所有主机完成状态,开始下一轮的迭代。

  下面把UpdateOneIterEvalOneIter单独进行详解。

2.2 UpdateOneIter流程

  UpdateOneIter流程主要有以下几个步骤:

  1. LazyInitDMatrix(train);
  2. PredictRaw(train, &preds_);
  3. obj_->GetGradient(preds_, train->info(), iter, &gpair_);
  4. gbm_->DoBoost(train, &gpair_, obj_.get());

2.2.1 LazyInitDMatrix过程

  LazyInitDMatrix采用lazy方式构建ColIter列迭代器,加载数据进内存为CSR格式存储,但是xgboost分裂点查找是基于特征内的实例数据,因此需要将CSR格式存储转化为CSC格式存储。HaveColAccess()函数判断DMatrix对象是否存在ColIter成员,不存在则构建:
  1. 根据树构建模式tree_method,取值范围:'auto', 'approx', 'exact', 'hist', 'gpu_exact', 'gpu_hist'等,默认设置为'auto'。使用'auto'自适应到具体的算法,对于数据量小于 222

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值