Spark xgboost原理源码分析
文章目录
- Spark xgboost原理源码分析
-
- 方法总结
- 中文源码
-
- XGBoostClassificationModel
- BoosterParams
- Booster
方法总结
该代码是XGBoost分类模型的实现,它继承了ProbabilisticClassificationModel类,并实现了XGBoostClassifierParams、InferenceParams、MLWritable和Serializable接口。主要包括以下几个部分:
-
构造函数:有两个构造函数,一个是带参数的构造函数,用于初始化模型的uid、numClasses和_booster;另一个是无参构造函数,在内部调用带参数的构造函数,默认设置numClasses为2。
-
nativeBooster方法:返回模型的_native booster实例,用于调用底层API。
-
trainingSummary变量和summary方法:trainingSummary保存训练摘要信息,summary方法用于获取模型在训练集上的摘要信息。
-
一些设置方法:包括设置leafPredictionCol、contribPredictionCol、treeLimit、missing、allowZeroForMissingValue和inferBatchSize等参数。
-
predict方法:用于对单个实例进行预测,根据numClasses的取值情况,返回相应的预测结果。
-
predictRaw方法和raw2probabilityInPlace方法:这两个方法并没有被使用,只是为了通过编译器的检查。
-
transformInternal方法:用于对数据集进行转换,生成预测结果的DataFrame。
-
produceResultIterator方法:用于生成预测结果的迭代器。
-
generateResultSchema方法:根据固定的Schema生成最终的结果Schema。
-
producePredictionItrs方法:生成预测结果的迭代器数组。
-
transform方法:对数据集进行转换,生成包含预测结果的DataFrame。
-
copy方法:复制模型。
-
write方法:返回XGBoostClassificationModelWriter实例,用于保存模型。
总体来说,该代码实现了XGBoost分类模型的训练、预测和保存等功能,并提供了一些设置方法和获取摘要信息的方法。
中文源码
XGBoostClassificationModel
import ml.dmlc.xgboost4j.scala.DMatrix
import org.apache.spark.ml.classification.{
ProbabilisticClassificationModel, ProbabilisticClassifier}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.MLWritable
import org.apache.spark.sql.Dataset
class XGBoostClassificationModel private[xgboost](
override val uid: String,
private[xgboost] val numClasses: Int,
@transient private[xgboost] val _booster: Booster)
extends ProbabilisticClassificationModel[Vector, XGBoostClassificationModel]
with XGBoostClassifierParams with InferenceParams with MLWritable with Serializable {
// 构造函数,用于初始化模型的uid、numClasses和_booster
def this(uid: String, _booster: Booster) =
this(uid, 2, _booster)
// 返回_native booster实例,用于调用底层API
private[xgboost] def nativeBooster: Booster = _booster
var trainingSummary: Option[XGBoostTrainingSummary] = None
// 获取训练集上的摘要信息
def summary: XGBoostTrainingSummary = trainingSummary.getOrElse {
throw new RuntimeException("No training summary available for this XGBoostClassificationModel")
}
// 设置leafPredictionCol参数
def setLeafPredictionCol(value: String): this.type = set(leafPredictionCol, value)
// 设置contribPredictionCol参数
def setContribPredictionCol(value: String): this.type = set(contribPredictionCol, value)
// 设置treeLimit参数
def setTreeLimit(value: Int): this.type = set(treeLimit, value)
// 设置missing参数
def setMissing(value: Float): this.type = set(missing, value)
// 设置allowZeroForMissingValue参数
def setAllowZeroForMissingValue(value: Boolean): this.type = set(allowZeroForMissingValue, value)
// 设置inferBatchSize参数
def setInferBatchSize(value: Int): this.type = set(inferBatchSize, value)
// 对单个实例进行预测,返回预测结果
override protected def predict(features: Vector): Double = {
val dMatrix = new DMatrix(features.asML, missingValue = getMissing)
val prediction = nativeBooster.predict(dMatrix, outPutMargin = true)
if (numClasses == 2) {
prediction.head
} else {
prediction.zipWithIndex.maxBy(_._1)._2.toDouble
}
}
// 以下两个方法并没有被使用,只是为了通过编译器的检查
private def predictRaw(features: Vector): Vector = {
throw new RuntimeException("XGBoostClassificationModel doesn't support predictRaw")
}
private def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
throw new RuntimeException("XGBoostClassificationModel doesn't support raw2probabilityInPlace")
}
// 对数据集进行转换,生成预测结果的DataFrame
override protected def transformInternal(dataset: Dataset[_]): DataFrame = {
val outputSchema = generateResultSchema(dataset.schema)
producePredictionItrs(dataset).zipWithIndex.foldLeft(dataset.toDF()) {
case (df, (it, index)) =>
val resultColName = if (getOutputCols.length > 1) {
getOutputCols(index)
} else {
getOutputCol
}
df.withColumn(resultColName, it)
}.select(getInputCols.map(dataset.col) ++ getOutputCols: _*).toDF().transform {
df =>
transformSchema(df.schema, logging = true)
df
}
}
// 生成预测结果的迭代器
private def produceResultIterator(
originalRowItr: Iterator[Row],
rawPredictionItr: Iterator[Row],
probabilityItr: Iterator[Row],
predLeafItr: Iterator[Row],
predContribItr: Iterator[Row]): Iterator[Row] = {
// the following implementation is to be improved
// 检查是否定义了leafPredictionCol和contribPredictionCol参数,并且它们非空
if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
isDefined(contribPredictionCol) && $(contribPredictionCol).nonEmpty) {
// 如果都定义了,则将原始数据、原始预测、概率、叶子节点预测和贡献度合并为一行数据
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).zip(predContribItr).
map {
case ((((originals: Row, rawPrediction: Row), probability: Row), leaves: Row),
contribs: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves.toSeq ++
contribs.toSeq)
}
} else if (isDefined(leafPredictionCol) && $(leafPredictionCol).nonEmpty &&
(!isDefined(contribPredictionCol) || $(contribPredictionCol).isEmpty)) {
// 如果只定义了leafPredictionCol参数,则将原始数据、原始预测、概率和叶子节点预测合并为一行数据
originalRowItr.zip(rawPredictionItr).zip(probabilityItr).zip(predLeafItr).
map {
case (((originals: Row, rawPrediction: Row), probability: Row), leaves: Row) =>
Row.fromSeq(originals.toSeq ++ rawPrediction.toSeq ++ probability.toSeq ++ leaves

本文对Spark XGBoost进行原理与源码分析。介绍了XGBoost分类模型代码,包括构造函数、nativeBooster方法等,实现了训练、预测和保存功能。还分析了Booster类方法,如predict、saveModel等,提供预测、保存加载及特征重要性评估功能。
最低0.47元/天 解锁文章
4003

被折叠的 条评论
为什么被折叠?



