spark mllib源码分析之随机森林(Random Forest)(四)

本文深入探讨Spark MLlib中随机森林的节点分裂过程,包括数据统计、节点分裂的最佳切分点寻找,以及连续和无序、有序特征的处理方法。文章详尽分析了如何计算节点的最优分裂点及其增益,以及如何进行节点的分裂和循环训练。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

spark源码分析之随机森林(Random Forest)(一)
spark源码分析之随机森林(Random Forest)(二)
spark源码分析之随机森林(Random Forest)(三)
spark源码分析之随机森林(Random Forest)(五)

6.4. node分裂

逻辑主要在DecisionTree.findBestSplits函数中,是RF训练最核心的部分

DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
6.4.1. 数据统计

数据统计分成两部分,先在各个partition上分别统计,再累积各partition成全局统计。

6.4.1.1. 取出node的特征子集
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)

取出各node的特征子集,如果不需要抽样则为None;否则返回Map[Int, Array[Int]],其实就是将之前treeToNodeToIndexInfo中的NodeIndexInfo转换为map结构,将其作为广播变量nodeToFeaturesBc。

6.4.1.2. 分区统计

一系列函数的调用链,我们逐层分析

val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          new DTStatsAggregator(metadata, featuresForNode)
        }

        // iterator all instances in current partition and update aggregate stats
        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))

        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    } else {
      input.mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          new DTStatsAggregator(metadata, featuresForNode)
        }

        // iterator all instances in current partition and update aggregate stats
        points.foreach(binSeqOp(nodeStatsAggregators, _))

        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    }

首先对每个partition构造一个DTStatsAggregator数组,长度是node的个数,注意这里实际使用的是数组,node怎样与自己的aggregator的对应?前面我们提到NodeIndexInfo的第一个成员是groupIndex,其值就是node的次序,和这里aggregator数组index其实是对应的,也就是说可以从NodeIndexInfo中取得groupIndex,然后作为数组index取得对应node的agg。DTStatsAggregator的入参是metadata和每个node的特征子集。然后将每个点统计到DTStatsAggregator中,其中调用了内部函数binSeqOp,

 /**
     * Performs a sequential aggregation over a partition.
     *
     * Each data point contributes to one node. For each feature,
     * the aggregate sufficient statistics are updated for the relevant bins.
     *
     * @param agg  Array storing aggregate calculation, with a set of sufficient statistics for
     *             each (node, feature, bin).
     * @param baggedPoint   Data point being aggregated.
     * @return  agg
     */
    def binSeqOp(
        agg: Array[DTStatsAggregator],
        baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
    //对每个node
      treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
        val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
          bins, metadata.unorderedFeatures)
        nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
      }

      agg
    }

首先调用函数predictNodeIndex计算nodeIndex,如果是首轮或者叶子节点,直接返回node.id;如果不是首轮,因为传入的是每棵树的root node,就从root node开始,逐渐往下判断该point应该是属于哪个node的,因为我们已经对node进行了分裂,这里其实实现了样本的划分。举个栗子,当前node如果是root的左孩子节点,而point预测节点应该属于右孩子,则调用nodeBinSepOp时就直接返回了,不会将这个point统计进去,用不大的时间换取样本集划分的空间,还是比较巧妙的。

/**
   * Get the node index corresponding to this data point.
   * Thi
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值