Spark ML中RandomForestClassifier RandomForestClassificationModel原理示例源码分析点击这里看全文
文章目录
原理
Spark ML中的随机森林分类器(RandomForestClassifier)是基于集成学习方法的一种分类模型。它由多个决策树组成,每个决策树都是通过对训练数据进行自助采样(bootstrap)和特征随机选择而生成的。
以下是Spark ML中随机森林分类器的工作原理:
-
数据准备:将输入的训练数据划分为若干个随机子样本。对于每个子样本,从原始数据集中有放回地采样相同数量的样本,形成一个新的训练集。同时,对于每个决策树,还会随机选择一部分特征用于构建树。
-
决策树的构建:对于每个子样本和随机选择的特征,使用决策树算法(如ID3、C4.5或CART)构建一个决策树模型。决策树的构建过程包括选择最佳的特征进行节点划分、递归地构建子树,直到达到停止条件(如树的深度达到预设值)。
-
集成学习:将所有构建好的决策树组合成随机森林模型。在分类问题中,每个决策树会根据样本的特征进行预测,并统计最终的类别投票结果。根据多数表决原则,选择票数最多的类别作为随机森林模型的最终预测结果。
-
特征重要性评估:在随机森林中,每个决策树都可以衡量特征的重要性。通过对所有决策树的特征重要性进行平均,得到整个随机森林模型的特征重要性评估。这可以帮助我们了解哪些特征对于分类结果的贡献更大。
-
预测:对于新的输入数据,随机森林模型会将该数据传递给每个决策树进行预测,然后根据决策树的投票结果得出最终的分类结果。
随机森林具有以下优点:
- 可以处理大量的训练数据,并能够处理高维度的特征。
- 对于缺失数据和噪声具有一定的鲁棒性。
- 能够评估特征的重要性,用于特征选择和分析。
- 在训练过程中,可以并行构建多个决策树,加快训练速度。
需要注意的是,随机森林模型的性能和泛化能力与决策树的数量、树的深度、特征选择策略等参数相关。在使用随机森林时,需要根据具体问题和数据集进行参数调优,以获得最佳的分类性能。
方法总结
RandomForestClassifier是Spark ML中用于分类任务的随机森林模型。下面是该类的一些重要方法的总结:
-
fit(dataset: Dataset[_]): RandomForestClassificationModel:使用给定的训练数据集拟合(训练)随机森林模型,并返回一个训练好的RandomForestClassificationModel对象。 -
setFeaturesCol(value: String): RandomForestClassifier:设置输入特征列的名称。 -
setPredictionCol(value: String): RandomForestClassifier:设置预测结果列的名称。 -
setLabelCol(value: String): RandomForestClassifier:设置标签列的名称,即目标变量。 -
setMaxDepth(value: Int): RandomForestClassifier:设置决策树的最大深度。 -
setNumTrees(value: Int): RandomForestClassifier:设置随机森林中决策树的数量。 -
setSubsamplingRate(value: Double): RandomForestClassifier:设置用于训练每个决策树的样本子集的比例。 -
setFeatureSubsetStrategy(value: String): RandomForestClassifier:设置特征子集选择策略,可以是"auto"、“all”、“onethird”、"sqrt"或"log2"之一。 -
setSeed(value: Long): RandomForestClassifier:设置随机数生成器的种子。 -
setImpurity(value: String): RandomForestClassifier:设置不纯度度量方法,可以是"entropy"(熵)或"gini"(基尼指数)之一。 -
setRawPredictionCol(value: String): RandomForestClassifier:设置原始预测结果列的名称。 -
setProbabilityCol(value: String): RandomForestClassifier:设置概率预测结果列的名称。 -
setWeightCol(value: String): RandomForestClassifier:设置样本权重列的名称。 -
setMaxBins(value: Int): RandomForestClassifier:设置连续特征离散化的最大箱数。 -
fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Array[RandomForestClassificationModel]:使用给定的训练数据集和参数网格搜索拟合多个随机森林模型,并返回一个包含多个训练好的模型的数组。 -
copy(extra: ParamMap): RandomForestClassifier:复制当前实例,可选地带有额外的参数。
这些方法允许您设置和调整随机森林模型的各种参数,以及在训练过程中控制模型的行为。通过适当选择和设置这些参数,可以优化模型的性能和预测准确度。
示例
以下是使用RandomForestClassifier进行分类任务的示例代码:
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{
IndexToString, StringIndexer, VectorAssembler}
import org.apache.spark.ml.Pipeline
// 读取训练数据集
val data = spark.read.format("csv")
.option("header", "true")
.option("inferSchema", "true")
.load("path/to/training_data.csv")
// 创建特征向量列
val featureColumns = Array("feature1", "feature2", "feature3")
val assembler = new VectorAssembler()
.setInputCols(featureColumns)
.setOutputCol("features")
val assembledData = assembler.transform(data)
// 对标签进行索引化
val labelIndexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(assembledData)
// 拆分数据集为训练集和测试集
val Array(trainingData, testData) = assembledData.randomSplit(Array(0.7, 0.3))
// 创建随机森林分类器
val rf = new RandomForestClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("features")
.setNumTrees(10)
// 将索引化的标签转换回原始标签
val labelConverter = new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(labelIndexer.labels)
// 构建Pipeline
val pipeline = new Pipeline()
.setStages(Array(labelIndexer, rf, labelConverter))
// 训练模型
val model = pipeline.fit(trainingData)
// 在测试集上进行预测
val predictions = model.transform(testData)
// 评估模型性能
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println("Accuracy: " + accuracy)
在这个示例中,首先加载训练数据集,并创建特征向量列。然后对标签进行索引化,并将数据集拆分为训练集和测试集。接下来,创建一个RandomForestClassifier对象,并设置相关参数。然后,使用Pipeline构建一个包含数据转换和模型训练的流水线。通过调用fit方法来训练模型。
最后,在测试集上进行预测并评估模型的性能。在这个示例中,我们使用了多分类准确度(accuracy)作为评估指标。
中文源码
class RandomForestClassifier
/**
* 随机森林(Random Forest)分类学习算法。
* 支持二进制和多类标签,以及连续和分类特征。
*/
@Since("1.4.0")
class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestClassifierParams with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("rfc"))
// 为了与Java API兼容性,重写父trait中的参数设置方法。
// TreeClassifierParams中的参数:
/** 设置树的最大深度 */
@Since("1.4.0")
override def setMaxDepth(value: Int): this.type = set(maxDepth, value)
/** 设置每个节点的最大分箱数 */
@Since("1.4.0")
override def setMaxBins(value: Int): this.type = set(maxBins, value)
/** 设置每个节点的最小实例数 */
@Since("1.4.0")
override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
/** 设置节点分裂所需的最小信息增益 */
@Since("1.4.0")
override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
/** 设置算法使用的内存上限 */
@Since("1.4.0")
override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
/** 设置是否缓存节点ID */
@Since("1.4.0")
override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
/**
* 设置检查点的频率,即多少次迭代进行一次缓存检查点。
* 仅在设置了cacheNodeIds为true并且在SparkContext中设置了检查点目录时才会使用。
* 必须至少为1。
* 默认值为10。
*/
@Since("1.4.0")
override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
/** 设置不纯度度量方法 */
@Since("1.4.0")
override def setImpurity(value: String): this.type = set(impurity, value)
// TreeEnsembleParams中的参数:
/** 设置子采样率 */
@Since("1.4.0")
override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
/** 设置随机种子 */
@Since("1.4.0")
override def setSeed(value: Long): this.type = set(seed, value)
// RandomForestParams中的参数:
/** 设置树的数量 */
@Since("1.4.0")
override def setNumTrees(value: Int): this.type = set(numTrees, value)
/** 设置特征子集策略 */
@Since("1.4.0")
override def setFeatureSubsetStrategy(value: String): this.type =
set(featureSubsetStrategy, value)
override protected def train(
dataset: Dataset[_]): RandomForestClassificationModel = instrumented {
instr =>
instr.logPipelineStage(this)
instr.logDataset(dataset)
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val numClasses: Int = getNumClasses(dataset)
if (isDefined(thresholds)) {
require($(thresholds).length == numClasses, this.getClass.getSimpleName +
".train() called with non-matching numClasses and thresholds.length." +
s" numClasses=$numClasses, but thresholds has length ${
$(thresholds).length}")
}
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
instr.logParams(this, labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol,
impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain,
minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval)
val trees = RandomForest
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])
val numFeatures = oldDataset.first().features.size
instr.logNumClasses(numClasses)
instr.logNumFeatures(numFeatures)
new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)
}
@Since("1.4.1")
override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
}
object RandomForestClassifier
object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] {
/** 支持的不纯度度量方法:熵(entropy)、基尼指数(gini) */
@Since("1.4.0")
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
/** 支持的特征子集策略:自动选择(auto)、全部(all)、三分之一(onethird)、平方根(sqrt)、对数(log2) */
@Since("1.4.0")
final val supportedFeatureSubsetStrategies: Array[String] =
TreeEnsembleParams.supportedFeatureSubsetStrategies
/** 加载模型 */
@Since("2.0.0")
override def load(path: String): RandomForestClassifier = super.load(path)
}
这部分代码定义了RandomForestClassifier对象,提供了一些静态方法和常量:
supportedImpurities:支持的不纯度度量方法,包括熵(entropy)和基尼指数(gini)。supportedFeatureSubsetStrategies:支持的特征子集策略,包括自动选择(auto)、全部(all)、三分之一(onethird)、平方根(sqrt)和对数(log2)。load方法:用于加载模型。
class RandomForestClassificationModel
/**
* 用于分类的随机森林(Random Forest)模型。
*
* @param _trees 集成中的决策树数组。
* 注意:这些树的父节点为null。
*/
@Since("1.4.0")
class RandomForestClassificationModel private[ml] (
@Since("1.5.0") override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel],
@Since("1.6.0") override val numFeatures: Int,
@Since("1.5.0") override val numClasses: Int)
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
with MLWritable with Serializable {
require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
/**
* 构造随机森林分类模型,所有树的权重相等。
*
* @param trees 组成模型的决策树数组
*/
private[ml] def this(
trees: Array[DecisionTreeClassificationModel],
numFeatures: Int,
numClasses: Int) =
this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
@Since("1.4.0")
override def trees: Array[DecisionTreeClassificationModel] = _trees
// 注意:我们可能会在以后添加根据树性能进行加权的支持。
private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)
@Since("1.4.0")
override def treeWeights: Array[Double] = _treeWeights
/**
* 将模型应用于数据集,生成预测结果的转换操作。
*
* @param dataset 输入的数据集
* @return 包含预测结果的新DataFrame
*/
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
val predictUDF = udf {
(features: Any) =>
bcastModel.value.predict(features.asInstanceOf[Vector])
}
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
/**
* 根据输入特征向量生成原始预测结果。
*
* @param features 输入的特征向量
* @return 原始预测结果向量
*/
override protected def predictRaw(features: Vector): Vector = {
// TODO: 当我们添加通用的Bagging类时,将在那里处理:SPARK-7128
// 使用多数表决进行分类。
// 目前忽略树权重,因为都是1.0。
val votes = Array.fill[Double](numClasses)(0.0)
_trees.view.foreach {
tree =>
val classCounts: Array[Double] = tree.rootNode.predictImpl(features).impurityStats.stats
val total = classCounts.sum
if (total != 0) {
var i = 0
while (i < numClasses) {
votes(i) += classCounts(i) / total
i += 1
}
}
}
Vectors.dense(votes)
}
/**
* 将原始预测结果转换为概率结果。
*
* @param rawPrediction 原始预测结果向量
* @return 概率预测结果向量
*/
override

本文深入探讨Spark ML中的RandomForestClassifier,介绍其原理、方法和示例代码。随机森林由多个决策树构成,通过自助采样和特征随机选择训练。文章还分析了模型的构建过程、特征重要性评估以及参数调优,帮助理解如何使用和优化随机森林模型。
最低0.47元/天 解锁文章
4070

被折叠的 条评论
为什么被折叠?



