Spark xgboost原理源码分析
文章目录
- Spark xgboost原理源码分析
- 方法总结
- 中文源码
- XGBoostClassificationModel
- BoosterParams
- Booster
方法总结
该代码是XGBoost分类模型的实现,它继承了ProbabilisticClassificationModel类,并实现了XGBoostClassifierParams、InferenceParams、MLWritable和Serializable接口。主要包括以下几个部分:
-
构造函数:有两个构造函数,一个是带参数的构造函数,用于初始化模型的uid、numClasses和_booster;另一个是无参构造函数,在内部调用带参数的构造函数,默认设置numClasses为2。
-
nativeBooster方法:返回模型的_native booster实例,用于调用底层API。
-
trainingSummary变量和summary方法:trainingSummary保存训练摘要信息,summary方法用于获取模型在训练集上的摘要信息。
-
一些设置方法:包括设置leafPredictionCol、contribPredictionCol、treeLimit、missing、allowZeroForMissingValue和inferBatchSize等参数。
-
predict方法:用于对单个实例进行预测,根据numClasses的取值情况,返回相应的预测结果。
-
predictRaw方法和raw2probabilityInPlace方法:这两个方法并没有被使用,只是为了通过编译器的检查。
-
transformInternal方法:用于对数据集进行转换,生成预测结果的DataFrame。
-
produceResultIterator方法:用于生成预测结果的迭代器。
-
generateResultSchema方法:根据固定的Schema生成最终的结果Schema。
-
producePredictionItrs方法:生成预测结果的迭代器数组。
-
transform方法:对数据集进行转换,生成包含预测结果的DataFrame。
-
copy方法:复制模型。
-
write方法:返回XGBoostClassificationModelWriter实例,用于保存模型。
总体来说,该代码实现了XGBoost分类模型的训练、预测和保存等功能,并提供了一些设置方法和获取摘要信息的方法。
中文源码
XGBoostClassificationModel
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.ml.classification.{ProbabilisticClassificationModel, ProbabilisticClassifier}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.MLWritable
import org.apache.spark.sql.Dataset
class XGBoostClassificationModel private[xgboost](
override val uid: String,
private[xgboost] val numClasses: Int,
@transient private[xgboost] val _booster: Booster)
extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
with XGBoostClassifierParams with InferenceParams with MLWritable with Serializable {
// 构造函数,用于初始化模型的uid、numClasses和_booster
def this(uid: String, _booster: Booster) =
this(uid, 2, _booster)
// 返回_native booster实例,用于调用底层API
private[xgboost] def nativeBooster: Booster = _booster
var trainingSummary: Option[XGBoostTrainingSummary] = None
// 获取训练集上的摘要信息
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
throw new RuntimeException("No training summary available for this XGBoostClassificationModel")
}
// 设置leafPredictionCol参数
def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
// 设置contribPredictionCol参数
def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)
// 设置treeLimit参数
def setTreeLimit(value: Int): this.type = set(treeLimit, value)
// 设置missing参数
def setMissing(value: Float): this.type = set(missing, value)
// 设置allowZeroForMissingValue参数
def setAllowZeroForMissingValue(value: Boolean): this.type = set(allowZeroForMissingValue, value)
// 设置inferBatchSize参数
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
// 对单个实例进行预测,返回预测结果
override protected def predict(features: Vector): Double = {
val dMatrix = new DMatrix(features.asML, missingValue = getMissing)
val prediction = nativeBooster.predict(dMatrix, outPutMargin = true)
if (numClasses == 2) {
prediction.head
} else {
prediction.zipWithIndex.maxBy(_._1)._2.toDouble
}
}
// 以下两个方法并没有被使用,只是为了通过编译器的检查
private def predictRaw(features: Vector): Vector = {
throw new RuntimeException("XGBoostClassificationModel doesn't support predictRaw")
}
private def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
throw new RuntimeException("XGBoostClassificationModel doesn't support raw2probabilityInPlace")
}
// 对数据集进行转换,生成预测结果的DataFrame
override protected def transformInternal(dataset: Dataset[_]): DataFrame = {
val outputSchema = generateResultSchema(dataset.schema)
producePredictionItrs(dataset).zipWithIndex.foldLeft(dataset.toDF()) {
case (df, (it, index)) =>
val resultColName = if (getOutputCols.length > 1) {
getOutputCols(index)
} else {
getOutputCol
}
df.withColumn(resultColName, it)
}.select(getInputCols.map(dataset.col) ++ getOutputCols: _*).toDF().transform {
df =>
transformSchema(df.schema, logging = true)
df
}
}
// 生成预测结果的迭代器
private def produceResultIterator(
originalRowItr: Iterator[Row],
rawPredictionItr: Iterator[Row],
probabilityItr: Iterator[Row],
predLeafItr: Iterator[Row],
predContribItr: Iterator[Row]): Iterator[Row] = {
// the following implementation is to be improved
// 检查是否定义了leafPredictionCol和contribPredictionCol参数,并且它们非空
if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
// 如果都定义了,则将原始数据、原始预测、概率、叶子节点预测和贡献度合并为一行数据
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).zip(predContribItr).
map { case ((((originals: Row, rawPrediction: Row), probability: Row), leaves: Row),
contribs: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq ++
contribs.toSeq)
}
} else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
(!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
// 如果只定义了leafPredictionCol参数,则将原始数据、原始预测、概率和叶子节点预测合并为一行数据
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).
map { case (((originals: Row, rawPrediction: Row), probability: Row), leaves: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq)
}
} else if ((!isDefined(leafPredictionCol) || $(leafPredictionCol).isEmpty) &&
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
// 如果只定义了contribPredictionCol参数,则将原始数据、原始预测、概率和贡献度合并为一行数据
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predContribItr).
map { case (((originals: Row, rawPrediction: Row), probability: Row), contribs: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ contribs.toSeq)
}
} else {
// 如果都未定义,则将原始数据、原始预测和概率合并为一行数据
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).map {
case ((originals: Row, rawPrediction: Row), probability: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq)
}
}
}
private def generateResultSchema(fixedSchema: StructType): StructType = {
var resultSchema = fixedSchema
// 检查是否定义了leafPredictionCol参数,并且非空
if (isDefined(leafPredictionCol)) {
// 添加leafPredictionCol字段到resultSchema中,数据类型为ArrayType(FloatType)
resultSchema = resultSchema.add(StructField(name = $(leafPredictionCol), dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
// 检查是否定义了contribPredictionCol参数,并且非空
if (isDefined(contribPredictionCol)) {
// 添加contribPredictionCol字段到resultSchema中,数据类型为ArrayType(FloatType)
resultSchema = resultSchema.add(StructField(name = $(contribPredictionCol), dataType =
ArrayType(FloatType, containsNull = false), nullable = false))
}
resultSchema
}
private def producePredictionItrs(broadcastBooster: Broadcast[Booster], dm: DMatrix):
Array[Iterator[Row]] = {
// 使用广播变量的Booster进行预测,返回原始预测值的迭代器
val rawPredictionItr = {
broadcastBooster.value.predict(dm, outPutMargin = true, $(treeLimit)).
map(Row(_)).iterator
}
// 使用广播变量的Booster进行预测,返回概率的迭代器
val probabilityItr = {
broadcastBooster.value.predict(dm, outPutMargin = false, $(treeLimit)).
map(Row(_)).iterator
}
// 如果定义了叶子节点预测列,则使用广播变量的Booster进行预测叶子节点,返回叶子节点预测值的迭代器
val predLeafItr = {
if (isDefined(leafPredictionCol)) {
broadcastBooster.value.predictLeaf(dm, $(treeLimit)).map(Row(_)).iterator
} else {
Iterator()
}
}
// 如果定义了贡献度预测列,则使用广播变量的Booster进行预测贡献度,返回贡献度预测值的迭代器
val predContribItr = {
if (isDefined(contribPredictionCol)) {
broadcastBooster.value.predictContrib(dm, $(treeLimit)).map(Row(_)).iterator
} else {
Iterator()
}
}
// 返回包含以上四个迭代器的数组
Array(rawPredictionItr, probabilityItr, predLeafItr, predContribItr)
}
override def transform(dataset: Dataset[_]): DataFrame = {
// 确保schema匹配,打印日志
transformSchema(dataset.schema, logging = true)
if (isDefined(thresholds)) {
// 如果定义了阈值,则要求阈值的长度与类别数相同
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".transform() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
// 只输出选定的列
// 这里有点复杂,因为它试图避免重复计算
var outputData = transformInternal(dataset)
var numColsOutput = 0
// 定义UDF来处理原始预测值
val rawPredictionUDF = udf { rawPrediction: mutable.WrappedArray[Float] =>
val raw = rawPrediction.map(_.toDouble).toArray
val rawPredictions = if (numClasses == 2) Array(-raw(0), raw(0)) else raw
Vectors.dense(rawPredictions)
}
// 定义UDF来处理概率
val probabilityUDF = udf { probability: mutable.WrappedArray[Float] =>
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
Vectors.dense(probabilities)
}
// 定义UDF来处理预测值
val predictUDF = udf { probability: mutable.WrappedArray[Float] =>
// 从XGBoost的概率转换为MLlib的预测值
val prob = probability.map(_.toDouble).toArray
val probabilities = if (numClasses == 2) Array(1.0 - prob(0), prob(0)) else prob
probability2prediction(Vectors.dense(probabilities))
}
// 如果定义了原始预测列,则将原始预测值应用UDF并添加到输出数据中
if ($(rawPredictionCol).nonEmpty) {
outputData = outputData
.withColumn(getRawPredictionCol, rawPredictionUDF(col(_rawPredictionCol)))
numColsOutput += 1
}
// 如果定义了概率列,则将概率应用UDF并添加到输出数据中
if ($(probabilityCol).nonEmpty) {
outputData = outputData
.withColumn(getProbabilityCol, probabilityUDF(col(_probabilityCol)))
numColsOutput += 1
}
// 如果定义了预测列,则将概率应用UDF并添加到输出数据中
if ($(predictionCol).nonEmpty) {
outputData = outputData
.withColumn($(predictionCol), predictUDF(col(_probabilityCol)))
numColsOutput += 1
}
// 如果没有输出列被设置,则打印警告日志
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
" since no output columns were set.")
}
// 返回输出数据,并丢弃原始预测列和概率列
outputData
.toDF
.drop(col(_rawPredictionCol))
.drop(col(_probabilityCol))
}
override def copy(extra: ParamMap): XGBoostClassificationModel = {
// 复制模型并设置额外的参数
val newModel = copyValues(new XGBoostClassificationModel(uid, numClasses, _booster), extra)
newModel.setSummary(summary).setParent(parent)
}
override def write: MLWriter =
// 返回XGBoostClassificationModelWriter实例用于写入模型
new XGBoostClassificationModel.XGBoostClassificationModelWriter(this)
}
BoosterParams
private[spark] trait BoosterParams extends Params {
/**
* 步长缩减系数,用于防止过拟合。每次提升步骤之后,
* 可以直接得到新特征的权重,eta实际上会缩小特征权重,
* 使提升过程更加保守。取值范围为[0, 1],默认值为0.3。
*/
final val eta = new DoubleParam(this, "eta", "步长缩减系数,用于防止过拟合。" +
"每次提升步骤之后,可以直接得到新特征的权重,并且eta实际上会缩小特征权重," +
"使提升过程更加保守。",
(value: Double) => value >= 0 && value <= 1)
final def getEta: Double = $(eta)
/**
* 叶子节点进一步划分所需的最小损失减少量。该值越大,算法越保守。
* 取值范围为[0, Double.MaxValue],默认值为0。
*/
final val gamma = new DoubleParam(this, "gamma", "叶子节点进一步划分所需的最小损失减少量。" +
"该值越大,算法越保守。",
(value: Double) => value >= 0)
final def getGamma: Double = $(gamma)
/**
* 树的最大深度,增加此值会使模型更复杂/更容易过拟合。
* 取值范围为[1, Int.MaxValue],默认值为6。
*/
final val maxDepth = new IntParam(this, "maxDepth", "树的最大深度,增加此值会使模型更复杂/更容易过拟合。",
(value: Int) => value >= 0)
final def getMaxDepth: Int = $(maxDepth)
/**
* 要添加的最大节点数。仅在设置了grow_policy=lossguide时相关。
* 默认值为0。
*/
final val maxLeaves = new IntParam(this, "maxLeaves",
"要添加的最大节点数。仅在设置了grow_policy=lossguide时相关。",
(value: Int) => value >= 0)
final def getMaxLeaves: Int = $(maxLeaves)
/**
* 子节点中所需的最小实例权重(Hessian)总和。如果树的划分导致叶子节点的实例权重总和小于min_child_weight,
* 则构建过程将停止进一步划分。在线性回归模式下,这简单地对应于每个节点所需的最小实例数。
* 该值越大,算法越保守。取值范围为[0, Double.MaxValue],默认值为1。
*/
final val minChildWeight = new DoubleParam(this, "minChildWeight", "子节点中所需的最小实例权重(Hessian)总和。" +
"如果树的划分导致叶子节点的实例权重总和小于min_child_weight,则构建过程将停止进一步划分。" +
"在线性回归模式下,这简单地对应于每个节点所需的最小实例数。" +
"该值越大,算法越保守。",
(value: Double) => value >= 0)
final def getMinChildWeight: Double = $(minChildWeight)
/**
* 最大增量步长限制,用于控制每棵树权重估计的最大变化值。如果该值设置为0,表示没有约束。
* 如果设置为正值,则可以帮助使更新步骤更加保守。通常情况下,不需要使用该参数,
* 但在类别极度不平衡的逻辑回归中可能会有所帮助。将其设置为1-10之间的值可以帮助控制更新。
* [默认值=0] 范围:[0, Double.MaxValue]
*/
final val maxDeltaStep = new DoubleParam(this, "maxDeltaStep", "最大增量步长限制,用于控制每棵树权重估计的最大变化值。如果该值设置为0,表示没有约束。如果设置为正值,则可以帮助使更新步骤更加保守。通常情况下,不需要使用该参数,但在类别极度不平衡的逻辑回归中可能会有所帮助。将其设置为1-10之间的值可以帮助控制更新",
(value: Double) => value >= 0)
final def getMaxDeltaStep: Double = $(maxDeltaStep)
/**
* 训练实例的子采样比例。将其设置为0.5表示XGBoost随机收集一半的数据实例用于构建树,以防止过拟合。
* [默认值=1] 范围:(0,1]
*/
final val subsample = new DoubleParam(this, "subsample", "训练实例的子采样比例。将其设置为0.5表示XGBoost随机收集一半的数据实例用于构建树,以防止过拟合。",
(value: Double) => value <= 1 && value > 0)
final def getSubsample: Double = $(subsample)
/**
* 构建每棵树时列的子采样比例。[默认值=1] 范围:(0,1]
*/
final val colsampleBytree = new DoubleParam(this, "colsampleBytree", "构建每棵树时列的子采样比例。",
(value: Double) => value <= 1 && value > 0)
final def getColsampleBytree: Double = $(colsampleBytree)
/**
* 每个级别中每次分割时列的子采样比例。[默认值=1] 范围:(0,1]
*/
final val colsampleBylevel = new DoubleParam(this, "colsampleBylevel", "每个级别中每次分割时列的子采样比例。",
(value: Double) => value <= 1 && value > 0)
final def getColsampleBylevel: Double = $(colsampleBylevel)
/**
* 权重的L2正则化项,增加此值会使模型更加保守。[默认值=1]
*/
final val lambda = new DoubleParam(this, "lambda", "权重的L2正则化项,增加此值会使模型更加保守。",
(value: Double) => value >= 0)
final def getLambda: Double = $(lambda)
/**
* 权重的L1正则化项,增加此值会使模型更加保守。[默认值=0]
*/
final val alpha = new DoubleParam(this, "alpha", "权重的L1正则化项,增加此值会使模型更加保守。",
(value: Double) => value >= 0)
final def getAlpha: Double = $(alpha)
/**
* XGBoost中使用的树构建算法。选项:{'auto', 'exact', 'approx'}
* [默认值='auto']
*/
final val treeMethod = new Param[String](this, "treeMethod",
"XGBoost中使用的树构建算法。选项:{'auto', 'exact', 'approx', 'hist'}",
(value: String) => BoosterParams.supportedTreeMethods.contains(value))
final def getTreeMethod: String = $(treeMethod)
/**
* 快速直方图算法的增长策略
*/
final val growPolicy = new Param[String](this, "growPolicy",
"控制新节点添加到树中的方式。仅当tree_method设置为hist时支持。选项:depthwise(从根节点开始拆分),lossguide(选择损失变化最大的节点进行拆分)。",
(value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
final def getGrowPolicy: String = $(growPolicy)
/**
* 直方图中的最大bin数
*/
final val maxBins = new IntParam(this, "maxBin", "直方图中的最大bin数",
(value: Int) => value > 0)
final def getMaxBins: Int = $(maxBins)
/**
* 仅用于近似贪心算法。粗略地等同于O(1/sketch_eps)个bin。与直接选择bin数相比,具有理论保证和草图准确性。
* [默认值=0.03] 范围:(0, 1)
*/
final val sketchEps = new DoubleParam(this, "sketchEps",
"仅用于近似贪心算法。粗略地等同于O(1/sketch_eps)个bin。与直接选择bin数相比,具有理论保证和草图准确性。",
(value: Double) => value < 1 && value > 0)
final def getSketchEps: Double = $(sketchEps)
/**
* 控制正负权重的平衡,对于不平衡的类别很有用。一个典型的值可以考虑:负样本总数 / 正样本总数。[默认值=1]
*/
final val scalePosWeight = new DoubleParam(this, "scalePosWeight", "控制正负权重的平衡,对于不平衡的类别很有用。一个典型的值可以考虑:负样本总数 / 正样本总数")
final def getScalePosWeight: Double = $(scalePosWeight)
// Dart增强器
/**
* Dart增强器的参数。采样算法的类型。"uniform":均匀选择被丢弃的树。"weighted":按权重比例选择被丢弃的树。[默认值="uniform"]
*/
final val sampleType = new Param[String](this, "sampleType", "采样算法的类型,选项:{'uniform', 'weighted'}",
(value: String) => BoosterParams.supportedSampleType.contains(value))
final def getSampleType: String = $(sampleType)
/**
* Dart增强器的参数。归一化算法的类型,选项:{'tree', 'forest'}。[默认值="tree"]
*/
final val normalizeType = new Param[String](this, "normalizeType", "归一化算法的类型,选项:{'tree', 'forest'}",
(value: String) => BoosterParams.supportedNormalizeType.contains(value))
final def getNormalizeType: String = $(normalizeType)
/**
* Dart增强器的参数。dropout率。[默认值=0.0] 范围:[0.0, 1.0]
*/
final val rateDrop = new DoubleParam(this, "rateDrop", "dropout率",
(value: Double) => value >= 0 && value <= 1)
final def getRateDrop: Double = $(rateDrop)
/**
* Dart增强器的参数。dropout率。[默认值=0.0] 范围:[0.0, 1.0]
*/
final val rateDrop = new DoubleParam(this, "rateDrop", "dropout率",
(value: Double) => value >= 0 && value <= 1)
final def getRateDrop: Double = $(rateDrop)
/**
* Dart增强器的参数。跳过dropout的概率。如果跳过了一个dropout,则以与gbtree相同的方式添加新树。[默认值=0.0] 范围:[0.0, 1.0]
*/
final val skipDrop = new DoubleParam(this, "skipDrop", "跳过dropout的概率。如果跳过了一个dropout,则以与gbtree相同的方式添加新树。",
(value: Double) => value >= 0 && value <= 1)
final def getSkipDrop: Double = $(skipDrop)
// 线性增强器
/**
* 线性增强器的参数。对偏差的L2正则化项,默认为0(没有L1正则化项,因为它不重要)
*/
final val lambdaBias = new DoubleParam(this, "lambdaBias", "对偏差的L2正则化项,默认为0(没有L1正则化项,因为它不重要)",
(value: Double) => value >= 0)
final def getLambdaBias: Double = $(lambdaBias)
final val treeLimit = new IntParam(this, name = "treeLimit",
doc = "在预测中使用的树的数量,默认为0(使用所有树)。")
final def getTreeLimit: Int = $(treeLimit)
final val monotoneConstraints = new Param[String](this, name = "monotoneConstraints",
doc = "特征数量的列表,1表示单调递增,-1表示单调递减,0表示没有约束。如果长度小于特征数量,则用0填充")
final def getMonotoneConstraints: String = $(monotoneConstraints)
final val interactionConstraints = new Param[String](this,
name = "interactionConstraints",
doc = "表示允许交互的约束的交互。约束必须以嵌套列表的形式指定,例如[[0, 1], [2, 3, 4]],其中每个内部列表是允许相互作用的特征的索引组。详细信息请参见教程")
setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
minChildWeight -> 1, maxDeltaStep -> 0,
growPolicy -> "depthwise", maxBins -> 16,
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
}
private[spark] object BoosterParams {
val supportedBoosters = HashSet("gbtree", "gblinear", "dart")
val supportedTreeMethods = HashSet("auto", "exact", "approx", "hist")
val supportedGrowthPolicies = HashSet("depthwise", "lossguide")
val supportedSampleType = HashSet("uniform", "weighted")
val supportedNormalizeType = HashSet("tree", "forest")
}
Booster
/**
* xgboost的增强器,这是一个支持交互式构建XGBoost模型的模型API
*
* 开发者警告:一个Java Booster不能被多个Scala Booster共享
* @param booster Java Booster对象
*/
class Booster private[xgboost4j](private[xgboost4j] var booster: JBooster)
extends Serializable with KryoSerializable {
/**
* 获取存储在Booster中的属性,并以Map形式返回。
*
* @return 包含属性对的Map
*/
@throws(classOf[XGBoostError])
def getAttrs: Map[String, String] = {
booster.getAttrs.asScala.toMap
}
/**
* 从Booster中获取属性。
*
* @param key 属性名
* @return 属性值
*/
@throws(classOf[XGBoostError])
def getAttr(key: String): String = {
booster.getAttr(key)
}
/**
* 将属性设置到Booster中。
*
* @param key 属性名
* @param value 属性值
*/
@throws(classOf[XGBoostError])
def setAttr(key: String, value: String): Unit = {
booster.setAttr(key, value)
}
/**
* 设置属性。
*
* @param params 属性的键值对Map
*/
@throws(classOf[XGBoostError])
def setAttrs(params: Map[String, String]): Unit = {
booster.setAttrs(params.asJava)
}
/**
* 将参数设置到Booster中。
*
* @param key 参数名
* @param value 参数值
*/
@throws(classOf[XGBoostError])
def setParam(key: String, value: AnyRef): Unit = {
booster.setParam(key, value)
}
/**
* 设置参数。
*
* @param params 参数的键值对Map
*/
@throws(classOf[XGBoostError])
def setParams(params: Map[String, AnyRef]): Unit = {
booster.setParams(params.asJava)
}
/**
* 更新(一次迭代)
*
* @param dtrain 训练数据
* @param iter 当前迭代次数
*/
@throws(classOf[XGBoostError])
def update(dtrain: DMatrix, iter: Int): Unit = {
booster.update(dtrain.jDMatrix, iter)
}
/**
* 使用自定义目标函数进行更新
*
* @param dtrain 训练数据
* @param obj 自定义目标函数类
*/
@throws(classOf[XGBoostError])
def update(dtrain: DMatrix, obj: ObjectiveTrait): Unit = {
booster.update(dtrain.jDMatrix, obj)
}
/**
* 使用给定的梯度和海森矩阵进行更新
*
* @param dtrain 训练数据
* @param grad 梯度的一阶导数
* @param hess 梯度的二阶导数
*/
@throws(classOf[XGBoostError])
def boost(dtrain: DMatrix, grad: Array[Float], hess: Array[Float]): Unit = {
booster.boost(dtrain.jDMatrix, grad, hess)
}
/**
* 对给定的矩阵进行评估。
*
* @param evalMatrixs 评估矩阵
* @param evalNames 评估名称,用于检查结果
* @param iter 当前评估迭代次数
* @return 评估信息
*/
@throws(classOf[XGBoostError])
def evalSet(evalMatrixs: Array[DMatrix], evalNames: Array[String], iter: Int): String = {
booster.evalSet(evalMatrixs.map(_.jDMatrix), evalNames, iter)
}
/**
* 使用给定的数据进行预测
*
* @param data 存储输入数据的dmatrix
* @param outPutMargin 是否输出未转换的原始边际值。
* @param treeLimit 限制预测中的树的数量;默认为0(使用所有树)。
* @return 预测结果
*/
@throws(classOf[XGBoostError])
def predict(data: DMatrix, outPutMargin: Boolean = false, treeLimit: Int = 0):
Array[Array[Float]] = {
booster.predict(data.jDMatrix, outPutMargin, treeLimit)
}
/**
* 预测叶子索引
*
* @param data 存储输入数据的dmatrix
* @param treeLimit 限制预测中的树的数量;默认为0(使用所有树)。
* @return 预测结果
* @throws XGBoostError 原生错误
*/
@throws(classOf[XGBoostError])
def predictLeaf(data: DMatrix, treeLimit: Int = 0): Array[Array[Float]] = {
booster.predictLeaf(data.jDMatrix, treeLimit)
}
/**
* 输出对给定数据的预测的特征贡献
*
* @param data 存储输入数据的dmatrix
* @param treeLimit 限制预测中的树的数量;默认为0(使用所有树)。
* @return 特征贡献和偏差
* @throws XGBoostError 原生错误
*/
@throws(classOf[XGBoostError])
def predictContrib(data: DMatrix, treeLimit: Int = 0): Array[Array[Float]] = {
booster.predictContrib(data.jDMatrix, treeLimit)
}
/**
* 将模型保存到指定的路径
*
* @param modelPath 模型路径
*/
@throws(classOf[XGBoostError])
def saveModel(modelPath: String): Unit = {
booster.saveModel(modelPath)
}
/**
* 将模型保存到输出流
*
* @param out 输出流
*/
@throws(classOf[XGBoostError])
def saveModel(out: java.io.OutputStream): Unit = {
booster.saveModel(out)
}
/**
* 将模型以字符串数组形式转储
*
* @param featureMap 特征映射文件
* @param withStats 是否输出拆分统计信息。
*/
@throws(classOf[XGBoostError])
def getModelDump(featureMap: String = null, withStats: Boolean = false, format: String = "text")
: Array[String] = {
booster.getModelDump(featureMap, withStats, format)
}
/**
* 使用指定的特征名称将模型转储为字符串数组。
*
* @param featureNames 特征名字数组。
*/
@throws(classOf[XGBoostError])
def getModelDump(featureNames: Array[String]): Array[String] = {
booster.getModelDump(featureNames, false, "text")
}
def getModelDump(featureNames: Array[String], withStats: Boolean, format: String)
: Array[String] = {
booster.getModelDump(featureNames, withStats, format)
}
/**
* 基于权重(分割数)获取每个特征的重要性
*
* @return featureScoreMap 键:特征索引,值:特征重要性分数
*/
@throws(classOf[XGBoostError])
def getFeatureScore(featureMap: String = null): mutable.Map[String, Integer] = {
booster.getFeatureScore(featureMap).asScala
}
/**
* 基于权重(分割数)获取每个特征的重要性,使用指定的特征名称。
*
* @return featureScoreMap 键:特征名字,值:特征重要性分数
*/
@throws(classOf[XGBoostError])
def getFeatureScore(featureNames: Array[String]): mutable.Map[String, Integer] = {
booster.getFeatureScore(featureNames).asScala
}
/**
* 根据信息增益或覆盖率获取每个特征的重要性
* 支持的类型:["gain", "cover", "total_gain", "total_cover"]
*
* @return featureScoreMap 键:特征索引,值:特征重要性分数
*/
@throws(classOf[XGBoostError])
def getScore(featureMap: String, importanceType: String): Map[String, Double] = {
Map(booster.getScore(featureMap, importanceType)
.asScala.mapValues(_.doubleValue).toSeq: _*)
}
/**
* 根据信息增益或覆盖率获取每个特征的重要性,使用指定的特征名称。
* 支持的类型:["gain", "cover", "total_gain", "total_cover"]
*
* @return featureScoreMap 键:特征名字,值:特征重要性分数
*/
@throws(classOf[XGBoostError])
def getScore(featureNames: Array[String], importanceType: String): Map[String, Double] = {
Map(booster.getScore(featureNames, importanceType)
.asScala.mapValues(_.doubleValue).toSeq: _*)
}
def getVersion: Int = booster.getVersion
def toByteArray: Array[Byte] = {
booster.toByteArray
}
/**
* 在不再需要时释放Booster
*/
def dispose: Unit = {
booster.dispose()
}
override def finalize(): Unit = {
super.finalize()
dispose
}
override def write(kryo: Kryo, output: Output): Unit = {
kryo.writeObject(output, booster)
}
override def read(kryo: Kryo, input: Input): Unit = {
booster = kryo.readObject(input, classOf[JBooster])
}
}
上述代码是Booster类的一些方法和函数的定义,提供了对XGBoost模型的预测、保存和加载、特征重要性评估等功能。具体包括以下方法:
- predict: 对给定数据进行预测。
- predictLeaf: 预测叶子索引。
- predictContrib: 输出对给定数据的预测的特征贡献。
- saveModel: 将模型保存到指定路径或输出流。
- getModelDump: 将模型转储为字符串数组。
- getFeatureScore: 获取每个特征的权重(分割数)重要性评估结果。
- getScore: 根据信息增益或覆盖率获取每个特征的重要性评估结果。
- getVersion: 获取XGBoost的版本号。
- toByteArray: 将Booster对象转换为字节数组。
- dispose: 释放Booster对象占用的资源。