scala源代码走读

这篇博客主要探讨了Scala MLlib库中的KMeans聚类算法实现。文章详细分析了KMeans模型的收敛过程,并提及了在Scala中数据持久化的重要性,包括cache()和persist()方法在节省IO时间和计算时间上的作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

今天先分析位于scala.mllib.clustering中最简单的KMeans模型,即文件KMeans.scala。

KMeans作为较简单的聚类算法,mllib中KMeans的实现方法也很简单。

KMeans类的定义
class KMeans private (
    private var k: Int,                        // 簇的个数
    private var maxIterations: Int,            // 模型迭代次数
    private var initializationMode: String,    // 初始化簇内中心点的算法
    private var initializationSteps: Int,      // 默认值是2,按照源代码说法这个一般不用调整
    private var epsilon: Double,               // 用于判断聚类中心收敛的距离阈值
    private var seed: Long,                    // 距离度量方法
    private var distanceMeasure: String)
    
    private var initialModel: Option[KMeansModel] = None  // 可以人为选择初始簇内中心点
    def setInitialModel(model: KMeansModel): this.type = {
        require(model.k == k, "mismatched cluster count")
        initialModel = Some(model)
        this
    }
    

初始化簇内中心点的算法分为"random" or "k-means||"

KMeans模型收敛过程的代码如下:

private[spark] def run(
      data: RDD[Vector],
      instr: Option[Instrumentation]): KMeansModel = {

    // KMeans模型需要迭代多次,因此数据需要被缓存到cache中
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("The input data is not directly cached, which may hurt performance if its"
        + " parent RDDs are also uncached.")
    }

    // Compute squared norms and cache them.  计算数据的二范数,即计算x*x;
    val norms = data.map(Vectors.norm(_, 2.0))
    norms.persist()                        // 对数据进行持久化操作,此处持久化到内存中
    // 将数据做成 (数据点, norm值)
    val zippedData = data.zip(norms).map { case (v, norm) => 
      new VectorWithNorm(v, norm)
    }
    val model = runAlgorithm(zippedData, instr)
    norms.unpersist()

    // Warn at the end of the run as well, for increased visibility.
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("The input data was not directly cached, which may hurt performance if its"
        + " parent RDDs are also uncached.")
    }
    model
  }

关于scala中数据的持久化,是在数据需要被多次使用时,通过数据持久化,以减少IO时间从而节约计算时间,详细可参考https://blog.youkuaiyun.com/asd136912/article/details/80885136

持久化的方法分为cache() 和 persist(),区别在于cache方法默认且只能缓存到内存,而persist方法自定义缓存级别

KMeans的核心

private def runAlgorithm(
      data: RDD[VectorWithNorm],
      instr: Option[Instrumentation]): KMeansModel = {

    val sc = data.sparkContext

    val initStartTime = System.nanoTime()

    // 距离度量方法,用于计算点点之间的距离
    val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)
    // 生成簇内中心点
    val centers = initialModel match {
      case Some(kMeansCenters) =>
        kMeansCenters.clusterCenters.map(new VectorWithNorm(_))  // 用户自己选择的中心点
      case None =>
        if (initializationMode == KMeans.RANDOM) {
          initRandom(data)                                       // 随机选择中心点
        } else {
          initKMeansParallel(data, distanceMeasureInstance)  // 使用Parallel方法生成中心点
        }
    }
    val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
    logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")

    var converged = false
    var cost = 0.0
    var iteration = 0

    val iterationStartTime = System.nanoTime()

    instr.foreach(_.logNumFeatures(centers.head.vector.size))

    // Execute iterations of Lloyd's algorithm until converged
    while (iteration < maxIterations && !converged) {
      val costAccum = sc.doubleAccumulator   // 创建累加器
      val bcCenters = sc.broadcast(centers)  // 创建广播变量

      // Find the new centers
      val collected = data.mapPartitions { points =>
        val thisCenters = bcCenters.value
        val dims = thisCenters.head.vector.size

        val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims)) // 创建二维数组,行数为中心点个数,列数为点的纬度
        val counts = Array.fill(thisCenters.length)(0L) // 创建一维数组

        points.foreach { point =>
          // 计算点与所有中心点的距离,返回最近的点及距离
          val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)  
          costAccum.add(cost)
          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
          counts(bestCenter) += 1  // 簇内个数计数
        }
        // sum中记录了同一个簇中所有点相加的和
        counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
      }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
        axpy(1.0, sum2, sum1)
        (sum1, count1 + count2)
      }.collectAsMap()

      if (iteration == 0) {
        instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
      }

      val newCenters = collected.mapValues { case (sum, count) =>
        distanceMeasureInstance.centroid(sum, count)
      }

      bcCenters.destroy()  // 中心点重新计算了

      // Update the cluster centers and costs
      converged = true
      newCenters.foreach { case (j, newCenter) =>
        if (converged &&
         !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) 
        // 判断更新的中心点与原有中心点的距离,如果小于阈值,则认为算法收敛
        {
          converged = false
        }
        centers(j) = newCenter
      }

      cost = costAccum.value
      iteration += 1
    }

    val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
    logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")

    if (iteration == maxIterations) {
      logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
    } else {
      logInfo(s"KMeans converged in $iteration iterations.")
    }

    logInfo(s"The cost is $cost.")

    new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration)
  }

 

第1章,“可伸展的语言”,给出了Scala的设计,和它后面的理由,历史的概要。 第2章,“Scala的第一步”,展示给你如何使用Scala完成若干种基本编程任务,而不牵涉过多关于如何工作的细节。本章的目的是让你的手指开始敲击并执行Scala代码。 第3章,“Scala的下一步”,演示更多的基本编程任务来帮助你更快速地上手Scala。本章之后,你将能够开始在简单的脚本任务中使用Scala。 第4章,“类和对象”,通过描述面向对象语言的基本建设模块和如何编译及运行Scala程序的教程开始有深度地覆盖Scala语言。 第5章,“基本类型和操作”,覆盖了Scala的基本类型,它们的文本,你可以执行的操作,优先级和关联性是如何工作的,还有什么是富包装器。 第6章,“函数式对象”,进入了Scala面向对象特征的更深层次,使用函数式(即,不可变)分数作为例子。 第7章,“内建控制结构”,显示了如何使用Scala的内建控制结构,如,if,while,for,try和match。 第8章,“函数和闭包”,深度讨论了函数式语言的基础建设模块,函数。 ...... 第31章,“组合子解析”,显示了如何使用Scala的解析器组合子库来创建解析器。 第32章,“GUI编程”,展示了使用Scala库简化基于Swing的GUI编程的快速旅程。 第33章,“SCell电子表”,通过展示一个完整的电子表的实现,集中演示了Scala的一切。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值