【Spark ML系列】 xgboost原理源码分析

本文对Spark XGBoost进行原理与源码分析。介绍了XGBoost分类模型代码,包括构造函数、nativeBooster方法等,实现了训练、预测和保存功能。还分析了Booster类方法,如predict、saveModel等,提供预测、保存加载及特征重要性评估功能。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

Spark xgboost原理源码分析

文章目录

  • Spark xgboost原理源码分析
    • 方法总结
    • 中文源码
      • XGBoostClassificationModel
      • BoosterParams
      • Booster

方法总结

该代码是XGBoost分类模型的实现,它继承了ProbabilisticClassificationModel类,并实现了XGBoostClassifierParams、InferenceParams、MLWritable和Serializable接口。主要包括以下几个部分:

  1. 构造函数:有两个构造函数,一个是带参数的构造函数,用于初始化模型的uid、numClasses和_booster;另一个是无参构造函数,在内部调用带参数的构造函数,默认设置numClasses为2。

  2. nativeBooster方法:返回模型的_native booster实例,用于调用底层API。

  3. trainingSummary变量和summary方法:trainingSummary保存训练摘要信息,summary方法用于获取模型在训练集上的摘要信息。

  4. 一些设置方法:包括设置leafPredictionCol、contribPredictionCol、treeLimit、missing、allowZeroForMissingValue和inferBatchSize等参数。

  5. predict方法:用于对单个实例进行预测,根据numClasses的取值情况,返回相应的预测结果。

  6. predictRaw方法和raw2probabilityInPlace方法:这两个方法并没有被使用,只是为了通过编译器的检查。

  7. transformInternal方法:用于对数据集进行转换,生成预测结果的DataFrame。

  8. produceResultIterator方法:用于生成预测结果的迭代器。

  9. generateResultSchema方法:根据固定的Schema生成最终的结果Schema。

  10. producePredictionItrs方法:生成预测结果的迭代器数组。

  11. transform方法:对数据集进行转换,生成包含预测结果的DataFrame。

  12. copy方法:复制模型。

  13. write方法:返回XGBoostClassificationModelWriter实例,用于保存模型。

总体来说,该代码实现了XGBoost分类模型的训练、预测和保存等功能,并提供了一些设置方法和获取摘要信息的方法。

中文源码

XGBoostClassificationModel

import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.ml.classification.{ProbabilisticClassificationModel, ProbabilisticClassifier}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.MLWritable
import org.apache.spark.sql.Dataset

class XGBoostClassificationModel private[xgboost](
    override val uid: String,
    private[xgboost] val numClasses: Int,
    @transient private[xgboost] val _booster: Booster)
  extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
  with XGBoostClassifierParams with InferenceParams with MLWritable with Serializable {

  // 构造函数,用于初始化模型的uid、numClasses和_booster
  def this(uid: String, _booster: Booster) =
    this(uid, 2, _booster)

  // 返回_native booster实例,用于调用底层API
  private[xgboost] def nativeBooster: Booster = _booster

  var trainingSummary: Option[XGBoostTrainingSummary] = None

  // 获取训练集上的摘要信息
  def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
    throw new RuntimeException("No training summary available for this XGBoostClassificationModel")
  }

  // 设置leafPredictionCol参数
  def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)

  // 设置contribPredictionCol参数
  def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)

  // 设置treeLimit参数
  def setTreeLimit(value: Int): this.type = set(treeLimit, value)

  // 设置missing参数
  def setMissing(value: Float): this.type = set(missing, value)

  // 设置allowZeroForMissingValue参数
  def setAllowZeroForMissingValue(value: Boolean): this.type = set(allowZeroForMissingValue, value)

  // 设置inferBatchSize参数
  def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)

  // 对单个实例进行预测,返回预测结果
  override protected def predict(features: Vector): Double = {
    val dMatrix = new DMatrix(features.asML, missingValue = getMissing)
    val prediction = nativeBooster.predict(dMatrix, outPutMargin = true)
    if (numClasses == 2) {
      prediction.head
    } else {
      prediction.zipWithIndex.maxBy(_._1)._2.toDouble
    }
  }

  // 以下两个方法并没有被使用,只是为了通过编译器的检查
  private def predictRaw(features: Vector): Vector = {
    throw new RuntimeException("XGBoostClassificationModel doesn't support predictRaw")
  }

  private def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
    throw new RuntimeException("XGBoostClassificationModel doesn't support raw2probabilityInPlace")
  }

  // 对数据集进行转换,生成预测结果的DataFrame
  override protected def transformInternal(dataset: Dataset[_]): DataFrame = {
    val outputSchema = generateResultSchema(dataset.schema)
    producePredictionItrs(dataset).zipWithIndex.foldLeft(dataset.toDF()) {
      case (df, (it, index)) =>
        val resultColName = if (getOutputCols.length > 1) {
          getOutputCols(index)
        } else {
          getOutputCol
        }
        df.withColumn(resultColName, it)
    }.select(getInputCols.map(dataset.col) ++ getOutputCols: _*).toDF().transform {
      df =>
        transformSchema(df.schema, logging = true)
        df
    }
  }

  // 生成预测结果的迭代器
  private def produceResultIterator(
      originalRowItr: Iterator[Row],
      rawPredictionItr: Iterator[Row],
      probabilityItr: Iterator[Row],
      predLeafItr: Iterator[Row],
      predContribItr: Iterator[Row]): Iterator[Row] = {
    // the following implementation is to be improved

    // 检查是否定义了leafPredictionCol和contribPredictionCol参数,并且它们非空
    if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
      isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
      // 如果都定义了,则将原始数据、原始预测、概率、叶子节点预测和贡献度合并为一行数据
      originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).zip(predContribItr).
        map { case ((((originals: Row, rawPrediction: Row), probability: Row), leaves: Row),
        contribs: Row) =>
          Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq ++
            contribs.toSeq)
      }
    } else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
      (!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
      // 如果只定义了leafPredictionCol参数,则将原始数据、原始预测、概率和叶子节点预测合并为一行数据
      originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).
        map { case (((originals: Row, rawPrediction: Row), probability: Row), leaves: Row) =>
          Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq)
        }
    } else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
      isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
      // 如果只定义了contribPredictionCol参数,则将原始数据、原始预测、概率和贡献度合并为一行数据
      originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predContribItr).
        map { case (((originals: Row, rawPrediction: Row), probability: Row), contribs: Row) =>
          Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ contribs.toSeq)
        }
    } else {
      // 如果都未定义,则将原始数据、原始预测和概率合并为一行数据
      originalRowItr.zip(rawPredictionItr).zip(probabilityItr).map {
        case ((originals: Row, rawPrediction: Row), probability: Row) =>
          Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
      }
    }
  }

  private def generateResultSchema(fixedSchema: StructType): StructType = {
    var resultSchema = fixedSchema
    // 检查是否定义了leafPredictionCol参数,并且非空
    if (isDefined(leafPredictionCol)) {
      // 添加leafPredictionCol字段到resultSchema中,数据类型为ArrayType(FloatType)
      resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
        ArrayType(FloatType, containsNull = false), nullable = false))
    }
    // 检查是否定义了contribPredictionCol参数,并且非空
    if (isDefined(contribPredictionCol)) {
      // 添加contribPredictionCol字段到resultSchema中,数据类型为ArrayType(FloatType)
      resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
        ArrayType(FloatType, containsNull = false), nullable = false))
    }
    resultSchema
  }

  private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
      Array[Iterator[Row]] = {
    // 使用广播变量的Booster进行预测,返回原始预测值的迭代器
    val rawPredictionItr = {
      broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)).
        map(Row(_)).iterator
    }
    // 使用广播变量的Booster进行预测,返回概率的迭代器
    val probabilityItr = {
      broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).
        map(Row(_)).iterator
    }
    // 如果定义了叶子节点预测列,则使用广播变量的Booster进行预测叶子节点,返回叶子节点预测值的迭代器
    val predLeafItr = {
      if (isDefined(leafPredictionCol)) {
        broadcastBooster.value.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator
      } else {
        Iterator()
      }
    }
    // 如果定义了贡献度预测列,则使用广播变量的Booster进行预测贡献度,返回贡献度预测值的迭代器
    val predContribItr = {
      if (isDefined(contribPredictionCol)) {
        broadcastBooster.value.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator
      } else {
        Iterator()
      }
    }
    // 返回包含以上四个迭代器的数组
    Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
  }

  override def transform(dataset: Dataset[_]): DataFrame = {
    // 确保schema匹配,打印日志
    transformSchema(dataset.schema, logging = true)
    if (isDefined(thresholds)) {
      // 如果定义了阈值,则要求阈值的长度与类别数相同
      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
        ".transform() called with non-matching numClasses and thresholds.length." +
        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
    }

    // 只输出选定的列
    // 这里有点复杂,因为它试图避免重复计算
    var outputData = transformInternal(dataset)
    var numColsOutput = 0

    // 定义UDF来处理原始预测值
    val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
      val raw = rawPrediction.map(_.toDouble).toArray
      val rawPredictions = if (numClasses == 2) Array(-raw(0), raw(0)) else raw
      Vectors.dense(rawPredictions)
    }

    // 定义UDF来处理概率
    val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
      val prob = probability.map(_.toDouble).toArray
      val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
      Vectors.dense(probabilities)
    }

    // 定义UDF来处理预测值
    val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
      // 从XGBoost的概率转换为MLlib的预测值
      val prob = probability.map(_.toDouble).toArray
      val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
      probability2prediction(Vectors.dense(probabilities))
    }

    // 如果定义了原始预测列,则将原始预测值应用UDF并添加到输出数据中
    if ($(rawPredictionCol).nonEmpty) {
      outputData = outputData
        .withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
      numColsOutput += 1
    }

    // 如果定义了概率列,则将概率应用UDF并添加到输出数据中
    if ($(probabilityCol).nonEmpty) {
      outputData = outputData
        .withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
      numColsOutput += 1
    }

    // 如果定义了预测列,则将概率应用UDF并添加到输出数据中
       if ($(predictionCol).nonEmpty) {
      outputData = outputData
        .withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
      numColsOutput += 1
    }

    // 如果没有输出列被设置,则打印警告日志
    if (numColsOutput == 0) {
      this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
        " since no output columns were set.")
    }
    // 返回输出数据,并丢弃原始预测列和概率列
    outputData
      .toDF
      .drop(col(_rawPredictionCol))
      .drop(col(_probabilityCol))
  }

  override def copy(extra: ParamMap): XGBoostClassificationModel = {
    // 复制模型并设置额外的参数
    val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses, _booster), extra)
    newModel.setSummary(summary).setParent(parent)
  }

  override def write: MLWriter =
    // 返回XGBoostClassificationModelWriter实例用于写入模型
    new XGBoostClassificationModel.XGBoostClassificationModelWriter(this)
}
   

BoosterParams

private[spark] trait BoosterParams extends Params {

  /**
   * 步长缩减系数,用于防止过拟合。每次提升步骤之后,
   * 可以直接得到新特征的权重,eta实际上会缩小特征权重,
   * 使提升过程更加保守。取值范围为[0, 1],默认值为0.3。
   */
  final val eta = new DoubleParam(this, "eta", "步长缩减系数,用于防止过拟合。" +
    "每次提升步骤之后,可以直接得到新特征的权重,并且eta实际上会缩小特征权重," +
    "使提升过程更加保守。",
    (value: Double) => value >= 0 && value <= 1)

  final def getEta: Double = $(eta)

  /**
   * 叶子节点进一步划分所需的最小损失减少量。该值越大,算法越保守。
   * 取值范围为[0, Double.MaxValue],默认值为0。
   */
  final val gamma = new DoubleParam(this, "gamma", "叶子节点进一步划分所需的最小损失减少量。" +
    "该值越大,算法越保守。",
    (value: Double) => value >= 0)

  final def getGamma: Double = $(gamma)

  /**
   * 树的最大深度,增加此值会使模型更复杂/更容易过拟合。
   * 取值范围为[1, Int.MaxValue],默认值为6。
   */
  final val maxDepth = new IntParam(this, "maxDepth", "树的最大深度,增加此值会使模型更复杂/更容易过拟合。",
    (value: Int) => value >= 0)

  final def getMaxDepth: Int = $(maxDepth)


  /**
   * 要添加的最大节点数。仅在设置了grow_policy=lossguide时相关。
   * 默认值为0。
   */
  final val maxLeaves = new IntParam(this, "maxLeaves",
    "要添加的最大节点数。仅在设置了grow_policy=lossguide时相关。",
    (value: Int) => value >= 0)

  final def getMaxLeaves: Int = $(maxLeaves)


  /**
   * 子节点中所需的最小实例权重(Hessian)总和。如果树的划分导致叶子节点的实例权重总和小于min_child_weight,
   * 则构建过程将停止进一步划分。在线性回归模式下,这简单地对应于每个节点所需的最小实例数。
   * 该值越大,算法越保守。取值范围为[0, Double.MaxValue],默认值为1。
   */
  final val minChildWeight = new DoubleParam(this, "minChildWeight", "子节点中所需的最小实例权重(Hessian)总和。" +
    "如果树的划分导致叶子节点的实例权重总和小于min_child_weight,则构建过程将停止进一步划分。" +
    "在线性回归模式下,这简单地对应于每个节点所需的最小实例数。" +
    "该值越大,算法越保守。",
    (value: Double) => value >= 0)

  final def getMinChildWeight: Double = $(minChildWeight)


    /**
     * 最大增量步长限制,用于控制每棵树权重估计的最大变化值。如果该值设置为0,表示没有约束。
     * 如果设置为正值,则可以帮助使更新步骤更加保守。通常情况下,不需要使用该参数,
     * 但在类别极度不平衡的逻辑回归中可能会有所帮助。将其设置为1-10之间的值可以帮助控制更新。
     * [默认值=0] 范围:[0, Double.MaxValue]
     */
    final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep", "最大增量步长限制,用于控制每棵树权重估计的最大变化值。如果该值设置为0,表示没有约束。如果设置为正值,则可以帮助使更新步骤更加保守。通常情况下,不需要使用该参数,但在类别极度不平衡的逻辑回归中可能会有所帮助。将其设置为1-10之间的值可以帮助控制更新",
      (value: Double) => value >= 0)

    final def getMaxDeltaStep: Double = $(maxDeltaStep)

    /**
     * 训练实例的子采样比例。将其设置为0.5表示XGBoost随机收集一半的数据实例用于构建树,以防止过拟合。
     * [默认值=1] 范围:(0,1]
     */
    final val subsample = new DoubleParam(this, "subsample", "训练实例的子采样比例。将其设置为0.5表示XGBoost随机收集一半的数据实例用于构建树,以防止过拟合。",
      (value: Double) => value <= 1 && value > 0)

    final def getSubsample: Double = $(subsample)

    /**
     * 构建每棵树时列的子采样比例。[默认值=1] 范围:(0,1]
     */
    final val colsampleBytree = new DoubleParam(this, "colsampleBytree", "构建每棵树时列的子采样比例。",
      (value: Double) => value <= 1 && value > 0)

    final def getColsampleBytree: Double = $(colsampleBytree)

    /**
     * 每个级别中每次分割时列的子采样比例。[默认值=1] 范围:(0,1]
     */
    final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel", "每个级别中每次分割时列的子采样比例。",
      (value: Double) => value <= 1 && value > 0)

    final def getColsampleBylevel: Double = $(colsampleBylevel)

    /**
     * 权重的L2正则化项,增加此值会使模型更加保守。[默认值=1]
     */
    final val lambda = new DoubleParam(this, "lambda", "权重的L2正则化项,增加此值会使模型更加保守。",
      (value: Double) => value >= 0)

    final def getLambda: Double = $(lambda)

    /**
     * 权重的L1正则化项,增加此值会使模型更加保守。[默认值=0]
     */
    final val alpha = new DoubleParam(this, "alpha", "权重的L1正则化项,增加此值会使模型更加保守。",
      (value: Double) => value >= 0)

    final def getAlpha: Double = $(alpha)


    /**
     * XGBoost中使用的树构建算法。选项:{'auto', 'exact', 'approx'}
     * [默认值='auto']
     */
    final val treeMethod = new Param[String](this, "treeMethod",
      "XGBoost中使用的树构建算法。选项:{'auto', 'exact', 'approx', 'hist'}",
      (value: String) => BoosterParams.supportedTreeMethods.contains(value))

    final def getTreeMethod: String = $(treeMethod)

    /**
     * 快速直方图算法的增长策略
     */
    final val growPolicy = new Param[String](this, "growPolicy",
      "控制新节点添加到树中的方式。仅当tree_method设置为hist时支持。选项:depthwise(从根节点开始拆分),lossguide(选择损失变化最大的节点进行拆分)。",
      (value: String) => BoosterParams.supportedGrowthPolicies.contains(value))

    final def getGrowPolicy: String = $(growPolicy)

    /**
     * 直方图中的最大bin数
     */
    final val maxBins = new IntParam(this, "maxBin", "直方图中的最大bin数",
      (value: Int) => value > 0)

    final def getMaxBins: Int = $(maxBins)

    /**
     * 仅用于近似贪心算法。粗略地等同于O(1/sketch_eps)个bin。与直接选择bin数相比,具有理论保证和草图准确性。
     * [默认值=0.03] 范围:(0, 1)
     */
    final val sketchEps = new DoubleParam(this, "sketchEps",
      "仅用于近似贪心算法。粗略地等同于O(1/sketch_eps)个bin。与直接选择bin数相比,具有理论保证和草图准确性。",
      (value: Double) => value < 1 && value > 0)

    final def getSketchEps: Double = $(sketchEps)

    /**
     * 控制正负权重的平衡,对于不平衡的类别很有用。一个典型的值可以考虑:负样本总数 / 正样本总数。[默认值=1]
     */
    final val scalePosWeight = new DoubleParam(this, "scalePosWeight", "控制正负权重的平衡,对于不平衡的类别很有用。一个典型的值可以考虑:负样本总数 / 正样本总数")

    final def getScalePosWeight: Double = $(scalePosWeight)

    // Dart增强器

    /**
     * Dart增强器的参数。采样算法的类型。"uniform":均匀选择被丢弃的树。"weighted":按权重比例选择被丢弃的树。[默认值="uniform"]
     */
    final val sampleType = new Param[String](this, "sampleType", "采样算法的类型,选项:{'uniform', 'weighted'}",
      (value: String) => BoosterParams.supportedSampleType.contains(value))

    final def getSampleType: String = $(sampleType)

    /**
     * Dart增强器的参数。归一化算法的类型,选项:{'tree', 'forest'}。[默认值="tree"]
     */
    final val normalizeType = new Param[String](this, "normalizeType", "归一化算法的类型,选项:{'tree', 'forest'}",
      (value: String) => BoosterParams.supportedNormalizeType.contains(value))

    final def getNormalizeType: String = $(normalizeType)

    /**
     * Dart增强器的参数。dropout率。[默认值=0.0] 范围:[0.0, 1.0]
     */
    final val rateDrop = new DoubleParam(this, "rateDrop", "dropout率",
      (value: Double) => value >= 0 && value <= 1)

    final def getRateDrop: Double = $(rateDrop)


    /**
     * Dart增强器的参数。dropout率。[默认值=0.0] 范围:[0.0, 1.0]
     */
    final val rateDrop = new DoubleParam(this, "rateDrop", "dropout率",
      (value: Double) => value >= 0 && value <= 1)

    final def getRateDrop: Double = $(rateDrop)

    /**
     * Dart增强器的参数。跳过dropout的概率。如果跳过了一个dropout,则以与gbtree相同的方式添加新树。[默认值=0.0] 范围:[0.0, 1.0]
     */
    final val skipDrop = new DoubleParam(this, "skipDrop", "跳过dropout的概率。如果跳过了一个dropout,则以与gbtree相同的方式添加新树。",
      (value: Double) => value >= 0 && value <= 1)

    final def getSkipDrop: Double = $(skipDrop)

    // 线性增强器

    /**
     * 线性增强器的参数。对偏差的L2正则化项,默认为0(没有L1正则化项,因为它不重要)
     */
    final val lambdaBias = new DoubleParam(this, "lambdaBias", "对偏差的L2正则化项,默认为0(没有L1正则化项,因为它不重要)",
      (value: Double) => value >= 0)

    final def getLambdaBias: Double = $(lambdaBias)

    final val treeLimit = new IntParam(this, name = "treeLimit",
      doc = "在预测中使用的树的数量,默认为0(使用所有树)。")

    final def getTreeLimit: Int = $(treeLimit)

    final val monotoneConstraints = new Param[String](this, name = "monotoneConstraints",
      doc = "特征数量的列表,1表示单调递增,-1表示单调递减,0表示没有约束。如果长度小于特征数量,则用0填充")

    final def getMonotoneConstraints: String = $(monotoneConstraints)

    final val interactionConstraints = new Param[String](this,
      name = "interactionConstraints",
      doc = "表示允许交互的约束的交互。约束必须以嵌套列表的形式指定,例如[[0, 1], [2, 3, 4]],其中每个内部列表是允许相互作用的特征的索引组。详细信息请参见教程")
    
    setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
    minChildWeight -> 1, maxDeltaStep -> 0,
    growPolicy -> "depthwise", maxBins -> 16,
    subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
    lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
    scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
    rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
}


private[spark] object BoosterParams {

  val supportedBoosters = HashSet("gbtree", "gblinear", "dart")

  val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist")

  val supportedGrowthPolicies = HashSet("depthwise", "lossguide")

  val supportedSampleType = HashSet("uniform", "weighted")

  val supportedNormalizeType = HashSet("tree", "forest")
}

Booster

/**
  * xgboost的增强器,这是一个支持交互式构建XGBoost模型的模型API
  *
  * 开发者警告:一个Java Booster不能被多个Scala Booster共享
  * @param booster Java Booster对象
  */
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
  extends Serializable with KryoSerializable {

  /**
   * 获取存储在Booster中的属性,并以Map形式返回。
   *
   * @return 包含属性对的Map
   */
  @throws(classOf[XGBoostError])
  def getAttrs: Map[String, String] = {
    booster.getAttrs.asScala.toMap
  }

  /**
   * 从Booster中获取属性。
   *
   * @param key 属性名
   * @return 属性值
   */
  @throws(classOf[XGBoostError])
  def getAttr(key: String): String = {
    booster.getAttr(key)
  }

  /**
   * 将属性设置到Booster中。
   *
   * @param key   属性名
   * @param value 属性值
   */
  @throws(classOf[XGBoostError])
  def setAttr(key: String, value: String): Unit = {
    booster.setAttr(key, value)
  }

  /**
   * 设置属性。
   *
   * @param params 属性的键值对Map
   */
  @throws(classOf[XGBoostError])
  def setAttrs(params: Map[String, String]): Unit = {
    booster.setAttrs(params.asJava)
  }

  /**
    * 将参数设置到Booster中。
    *
    * @param key   参数名
    * @param value 参数值
    */
  @throws(classOf[XGBoostError])
  def setParam(key: String, value: AnyRef): Unit = {
    booster.setParam(key, value)
  }

  /**
   * 设置参数。
   *
   * @param params 参数的键值对Map
   */
  @throws(classOf[XGBoostError])
  def setParams(params: Map[String, AnyRef]): Unit = {
    booster.setParams(params.asJava)
  }

  /**
   * 更新(一次迭代)
   *
   * @param dtrain 训练数据
   * @param iter   当前迭代次数
   */
  @throws(classOf[XGBoostError])
  def update(dtrain: DMatrix, iter: Int): Unit = {
    booster.update(dtrain.jDMatrix, iter)
  }

  /**
   * 使用自定义目标函数进行更新
   *
   * @param dtrain 训练数据
   * @param obj    自定义目标函数类
   */
  @throws(classOf[XGBoostError])
  def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
    booster.update(dtrain.jDMatrix, obj)
  }

  /**
   * 使用给定的梯度和海森矩阵进行更新
   *
   * @param dtrain 训练数据
   * @param grad   梯度的一阶导数
   * @param hess   梯度的二阶导数
   */
  @throws(classOf[XGBoostError])
  def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
    booster.boost(dtrain.jDMatrix, grad, hess)
  }

  /**
   * 对给定的矩阵进行评估。
   *
   * @param evalMatrixs 评估矩阵
   * @param evalNames   评估名称,用于检查结果
   * @param iter        当前评估迭代次数
   * @return 评估信息
   */
  @throws(classOf[XGBoostError])
  def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
    booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
  }

  /**
     * 使用给定的数据进行预测
     *
     * @param data         存储输入数据的dmatrix
     * @param outPutMargin 是否输出未转换的原始边际值。
     * @param treeLimit    限制预测中的树的数量;默认为0(使用所有树)。
     * @return 预测结果
     */
    @throws(classOf[XGBoostError])
    def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0):
        Array[Array[Float]] = {
      booster.predict(data.jDMatrix, outPutMargin, treeLimit)
    }

    /**
     * 预测叶子索引
     *
     * @param data      存储输入数据的dmatrix
     * @param treeLimit 限制预测中的树的数量;默认为0(使用所有树)。
     * @return 预测结果
     * @throws XGBoostError 原生错误
     */
    @throws(classOf[XGBoostError])
    def predictLeaf(data: DMatrix, treeLimit: Int = 0): Array[Array[Float]] = {
      booster.predictLeaf(data.jDMatrix, treeLimit)
    }

    /**
      * 输出对给定数据的预测的特征贡献
      *
      * @param data      存储输入数据的dmatrix
      * @param treeLimit 限制预测中的树的数量;默认为0(使用所有树)。
      * @return 特征贡献和偏差
      * @throws XGBoostError 原生错误
      */
    @throws(classOf[XGBoostError])
    def predictContrib(data: DMatrix, treeLimit: Int = 0): Array[Array[Float]] = {
      booster.predictContrib(data.jDMatrix, treeLimit)
    }

    /**
     * 将模型保存到指定的路径
     *
     * @param modelPath 模型路径
     */
    @throws(classOf[XGBoostError])
    def saveModel(modelPath: String): Unit = {
      booster.saveModel(modelPath)
    }

    /**
      * 将模型保存到输出流
      *
      * @param out 输出流
      */
    @throws(classOf[XGBoostError])
    def saveModel(out: java.io.OutputStream): Unit = {
      booster.saveModel(out)
    }

    /**
     * 将模型以字符串数组形式转储
     *
     * @param featureMap 特征映射文件
     * @param withStats  是否输出拆分统计信息。
     */
    @throws(classOf[XGBoostError])
    def getModelDump(featureMap: String = null, withStats: Boolean = false, format: String = "text")
      : Array[String] = {
      booster.getModelDump(featureMap, withStats, format)
    }

    /**
      * 使用指定的特征名称将模型转储为字符串数组。
      *
      * @param featureNames 特征名字数组。
      */
    @throws(classOf[XGBoostError])
    def getModelDump(featureNames: Array[String]): Array[String] = {
      booster.getModelDump(featureNames, false, "text")
    }

    def getModelDump(featureNames: Array[String], withStats: Boolean, format: String)
      : Array[String] = {
      booster.getModelDump(featureNames, withStats, format)
    }

    /**
     * 基于权重(分割数)获取每个特征的重要性
     *
     * @return featureScoreMap  键:特征索引,值:特征重要性分数
     */
    @throws(classOf[XGBoostError])
    def getFeatureScore(featureMap: String = null): mutable.Map[String, Integer] = {
      booster.getFeatureScore(featureMap).asScala
    }

    /**
      * 基于权重(分割数)获取每个特征的重要性,使用指定的特征名称。
      *
      * @return featureScoreMap  键:特征名字,值:特征重要性分数
      */
    @throws(classOf[XGBoostError])
    def getFeatureScore(featureNames: Array[String]): mutable.Map[String, Integer] = {
      booster.getFeatureScore(featureNames).asScala
    }

    /**
      * 根据信息增益或覆盖率获取每个特征的重要性
      * 支持的类型:["gain", "cover", "total_gain", "total_cover"]
      *
      * @return featureScoreMap  键:特征索引,值:特征重要性分数
      */
    @throws(classOf[XGBoostError])
    def getScore(featureMap: String, importanceType: String): Map[String, Double] = {
      Map(booster.getScore(featureMap, importanceType)
          .asScala.mapValues(_.doubleValue).toSeq: _*)
    }

    /**
      * 根据信息增益或覆盖率获取每个特征的重要性,使用指定的特征名称。
      * 支持的类型:["gain", "cover", "total_gain", "total_cover"]
      *
      * @return featureScoreMap  键:特征名字,值:特征重要性分数
      */
    @throws(classOf[XGBoostError])
    def getScore(featureNames: Array[String], importanceType: String): Map[String, Double] = {
      Map(booster.getScore(featureNames, importanceType)
          .asScala.mapValues(_.doubleValue).toSeq: _*)
    }

    def getVersion: Int = booster.getVersion

    def toByteArray: Array[Byte] = {
      booster.toByteArray
    }

    /**
      * 在不再需要时释放Booster
      */
    def dispose: Unit = {
      booster.dispose()
    }

    override def finalize(): Unit = {
      super.finalize()
      dispose
    }

    override def write(kryo: Kryo, output: Output): Unit = {
      kryo.writeObject(output, booster)
    }

    override def read(kryo: Kryo, input: Input): Unit = {
      booster = kryo.readObject(input, classOf[JBooster])
    }
}

上述代码是Booster类的一些方法和函数的定义,提供了对XGBoost模型的预测、保存和加载、特征重要性评估等功能。具体包括以下方法:

  • predict: 对给定数据进行预测。
  • predictLeaf: 预测叶子索引。
  • predictContrib: 输出对给定数据的预测的特征贡献。
  • saveModel: 将模型保存到指定路径或输出流。
  • getModelDump: 将模型转储为字符串数组。
  • getFeatureScore: 获取每个特征的权重(分割数)重要性评估结果。
  • getScore: 根据信息增益或覆盖率获取每个特征的重要性评估结果。
  • getVersion: 获取XGBoost的版本号。
  • toByteArray: 将Booster对象转换为字节数组。
  • dispose: 释放Booster对象占用的资源。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BigDataMLApplication

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值