【Spark ML系列】 xgboost原理源码分析

本文对Spark XGBoost进行原理与源码分析。介绍了XGBoost分类模型代码,包括构造函数、nativeBooster方法等,实现了训练、预测和保存功能。还分析了Booster类方法,如predict、saveModel等,提供预测、保存加载及特征重要性评估功能。

Spark xgboost原理源码分析

文章目录

  • Spark xgboost原理源码分析
    • 方法总结
    • 中文源码
      • XGBoostClassificationModel
      • BoosterParams
      • Booster

方法总结

该代码是XGBoost分类模型的实现,它继承了ProbabilisticClassificationModel类,并实现了XGBoostClassifierParams、InferenceParams、MLWritable和Serializable接口。主要包括以下几个部分:

  1. 构造函数:有两个构造函数,一个是带参数的构造函数,用于初始化模型的uid、numClasses和_booster;另一个是无参构造函数,在内部调用带参数的构造函数,默认设置numClasses为2。

  2. nativeBooster方法:返回模型的_native booster实例,用于调用底层API。

  3. trainingSummary变量和summary方法:trainingSummary保存训练摘要信息,summary方法用于获取模型在训练集上的摘要信息。

  4. 一些设置方法:包括设置leafPredictionCol、contribPredictionCol、treeLimit、missing、allowZeroForMissingValue和inferBatchSize等参数。

  5. predict方法:用于对单个实例进行预测,根据numClasses的取值情况,返回相应的预测结果。

  6. predictRaw方法和raw2probabilityInPlace方法:这两个方法并没有被使用,只是为了通过编译器的检查。

  7. transformInternal方法:用于对数据集进行转换,生成预测结果的DataFrame。

  8. produceResultIterator方法:用于生成预测结果的迭代器。

  9. generateResultSchema方法:根据固定的Schema生成最终的结果Schema。

  10. producePredictionItrs方法:生成预测结果的迭代器数组。

  11. transform方法:对数据集进行转换,生成包含预测结果的DataFrame。

  12. copy方法:复制模型。

  13. write方法:返回XGBoostClassificationModelWriter实例,用于保存模型。

总体来说,该代码实现了XGBoost分类模型的训练、预测和保存等功能,并提供了一些设置方法和获取摘要信息的方法。

中文源码

XGBoostClassificationModel

import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.ml.classification.{
   
   ProbabilisticClassificationModel, ProbabilisticClassifier}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.MLWritable
import org.apache.spark.sql.Dataset

class XGBoostClassificationModel private[xgboost](
    override val uid: String,
    private[xgboost] val numClasses: Int,
    @transient private[xgboost] val _booster: Booster)
  extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
  with XGBoostClassifierParams with InferenceParams with MLWritable with Serializable {
   
   

  // 构造函数,用于初始化模型的uid、numClasses和_booster
  def this(uid: String, _booster: Booster) =
    this(uid, 2, _booster)

  // 返回_native booster实例,用于调用底层API
  private[xgboost] def nativeBooster: Booster = _booster

  var trainingSummary: Option[XGBoostTrainingSummary] = None

  // 获取训练集上的摘要信息
  def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
   
   
    throw new RuntimeException("No training summary available for this XGBoostClassificationModel")
  }

  // 设置leafPredictionCol参数
  def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)

  // 设置contribPredictionCol参数
  def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)

  // 设置treeLimit参数
  def setTreeLimit(value: Int): this.type = set(treeLimit, value)

  // 设置missing参数
  def setMissing(value: Float): this.type = set(missing, value)

  // 设置allowZeroForMissingValue参数
  def setAllowZeroForMissingValue(value: Boolean): this.type = set(allowZeroForMissingValue, value)

  // 设置inferBatchSize参数
  def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)

  // 对单个实例进行预测,返回预测结果
  override protected def predict(features: Vector): Double = {
   
   
    val dMatrix = new DMatrix(features.asML, missingValue = getMissing)
    val prediction = nativeBooster.predict(dMatrix, outPutMargin = true)
    if (numClasses == 2) {
   
   
      prediction.head
    } else {
   
   
      prediction.zipWithIndex.maxBy(_._1)._2.toDouble
    }
  }

  // 以下两个方法并没有被使用,只是为了通过编译器的检查
  private def predictRaw(features: Vector): Vector = {
   
   
    throw new RuntimeException("XGBoostClassificationModel doesn't support predictRaw")
  }

  private def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
   
   
    throw new RuntimeException("XGBoostClassificationModel doesn't support raw2probabilityInPlace")
  }

  // 对数据集进行转换,生成预测结果的DataFrame
  override protected def transformInternal(dataset: Dataset[_]): DataFrame = {
   
   
    val outputSchema = generateResultSchema(dataset.schema)
    producePredictionItrs(dataset).zipWithIndex.foldLeft(dataset.toDF()) {
   
   
      case (df, (it, index)) =>
        val resultColName = if (getOutputCols.length > 1) {
   
   
          getOutputCols(index)
        } else {
   
   
          getOutputCol
        }
        df.withColumn(resultColName, it)
    }.select(getInputCols.map(dataset.col) ++ getOutputCols: _*).toDF().transform {
   
   
      df =>
        transformSchema(df.schema, logging = true)
        df
    }
  }

  // 生成预测结果的迭代器
  private def produceResultIterator(
      originalRowItr: Iterator[Row],
      rawPredictionItr: Iterator[Row],
      probabilityItr: Iterator[Row],
      predLeafItr: Iterator[Row],
      predContribItr: Iterator[Row]): Iterator[Row] = {
   
   
    // the following implementation is to be improved

    // 检查是否定义了leafPredictionCol和contribPredictionCol参数,并且它们非空
    if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
      isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
   
   
      // 如果都定义了,则将原始数据、原始预测、概率、叶子节点预测和贡献度合并为一行数据
      originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).zip(predContribItr).
        map {
   
    case ((((originals: Row, rawPrediction: Row), probability: Row), leaves: Row),
        contribs: Row) =>
          Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq ++
            contribs.toSeq)
      }
    } else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
      (!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
   
   
      // 如果只定义了leafPredictionCol参数,则将原始数据、原始预测、概率和叶子节点预测合并为一行数据
      originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).
        map {
   
    case (((originals: Row, rawPrediction: Row), probability: Row), leaves: Row) =>
          Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值