Today we look at the simplest model in the org.apache.spark.mllib.clustering package, namely KMeans, implemented in the file KMeans.scala.
KMeans is a relatively simple clustering algorithm, and the mllib implementation is correspondingly straightforward.
The definition of the KMeans class
class KMeans private (
    private var k: Int,                     // number of clusters
    private var maxIterations: Int,         // maximum number of iterations
    private var initializationMode: String, // algorithm used to initialize the cluster centers
    private var initializationSteps: Int,   // defaults to 2; per the source comments this rarely needs tuning
    private var epsilon: Double,            // distance threshold used to decide that the centers have converged
    private var seed: Long,                 // random seed
    private var distanceMeasure: String)    // distance measure used between points
  private var initialModel: Option[KMeansModel] = None // lets the caller supply the initial cluster centers

  def setInitialModel(model: KMeansModel): this.type = {
    require(model.k == k, "mismatched cluster count")
    initialModel = Some(model)
    this
  }
The initialization mode for the cluster centers is either "random" or "k-means||".
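Before walking through the training code, here is a minimal usage sketch showing how these private fields are set through the public builder-style setters. This is my own illustration, not code from KMeans.scala; the SparkContext sc, the input path, and all parameter values are assumptions made for the example.

import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical input: one point per line, space-separated numbers.
val data = sc.textFile("data/kmeans_points.txt")              // path is an assumption
  .map(line => Vectors.dense(line.split(' ').map(_.toDouble)))
  .cache()                                                    // the algorithm passes over the data many times

val kmeans = new KMeans()
  .setK(3)                                        // number of clusters
  .setMaxIterations(20)                           // cap on Lloyd iterations
  .setInitializationMode(KMeans.K_MEANS_PARALLEL) // "k-means||"; KMeans.RANDOM is the alternative
  .setEpsilon(1e-4)                               // convergence threshold on center movement
  .setSeed(42L)                                   // fixed seed for reproducibility

val model = kmeans.run(data)                      // returns a KMeansModel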
The code that drives the KMeans model to convergence is the run method:
  private[spark] def run(
      data: RDD[Vector],
      instr: Option[Instrumentation]): KMeansModel = {

    // KMeans iterates over the data many times, so the input should already be cached
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("The input data is not directly cached, which may hurt performance if its"
        + " parent RDDs are also uncached.")
    }

    // Compute squared norms and cache them. (Vectors.norm(_, 2.0) is the 2-norm of each vector, i.e. sqrt(x · x).)
    val norms = data.map(Vectors.norm(_, 2.0))
    norms.persist() // persist the norms; with no arguments they are cached in memory

    // pair every data point with its norm, wrapped as a VectorWithNorm
    val zippedData = data.zip(norms).map { case (v, norm) =>
      new VectorWithNorm(v, norm)
    }
    val model = runAlgorithm(zippedData, instr)
    norms.unpersist()

    // Warn at the end of the run as well, for increased visibility.
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("The input data was not directly cached, which may hurt performance if its"
        + " parent RDDs are also uncached.")
    }
    model
  }
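Why compute the norms up front? For Euclidean distance the squared distance expands as ||x - c||^2 = ||x||^2 + ||c||^2 - 2 (x · c), so once the norms are cached only a dot product is needed per (point, center) pair, and (||x|| - ||c||)^2 gives a cheap lower bound that lets a center be skipped when it cannot beat the best distance found so far. Below is a small standalone sketch of the idea; the helper names are mine and this is not the actual DistanceMeasure code (the real implementation additionally guards against floating-point cancellation by falling back to an exact computation when needed).

// Sketch only (hypothetical helpers, not the real DistanceMeasure code).
def dot(x: Array[Double], y: Array[Double]): Double =
  x.iterator.zip(y.iterator).map { case (a, b) => a * b }.sum

def norm(x: Array[Double]): Double = math.sqrt(dot(x, x))

// ||x - c||^2 = ||x||^2 + ||c||^2 - 2 * (x · c):
// only the dot product has to be computed once the norms are known.
def fastSquaredDistance(x: Array[Double], xNorm: Double,
                        c: Array[Double], cNorm: Double): Double =
  xNorm * xNorm + cNorm * cNorm - 2.0 * dot(x, c)

// By the triangle inequality ||x - c|| >= | ||x|| - ||c|| |, so (xNorm - cNorm)^2
// is a lower bound on the squared distance; a center can be skipped whenever this
// bound already exceeds the best squared distance found so far.
def lowerBoundOfSqDist(xNorm: Double, cNorm: Double): Double = {
  val d = xNorm - cNorm
  d * d
}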
A note on RDD persistence in Spark: when a dataset is reused many times (as it is here, across KMeans iterations), persisting it avoids recomputing or re-reading it on every pass, trading memory for I/O and computation time; see https://blog.youkuaiyun.com/asd136912/article/details/80885136 for details.
There are two ways to persist an RDD, cache() and persist(): cache() can only store the data in memory (it is equivalent to persist with the default MEMORY_ONLY level), while persist() lets you choose the storage level yourself.
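As a quick illustration (my own sketch, not code from KMeans.scala), assuming an existing SparkContext sc and hypothetical input paths:

import org.apache.spark.storage.StorageLevel

val a = sc.textFile("data/a.txt")         // hypothetical inputs
val b = sc.textFile("data/b.txt")

a.cache()                                 // shorthand for a.persist(StorageLevel.MEMORY_ONLY)
b.persist(StorageLevel.MEMORY_AND_DISK)   // custom level: spill to disk when memory runs out

// ... actions that reuse a and b ...

a.unpersist()                             // release the cached blocks once they are no longer needed
b.unpersist()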
The core of KMeans: runAlgorithm
  private def runAlgorithm(
      data: RDD[VectorWithNorm],
      instr: Option[Instrumentation]): KMeansModel = {

    val sc = data.sparkContext
    val initStartTime = System.nanoTime()

    // distance measure used to compute point-to-point distances
    val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)

    // pick the initial cluster centers
    val centers = initialModel match {
      case Some(kMeansCenters) =>
        kMeansCenters.clusterCenters.map(new VectorWithNorm(_)) // centers supplied by the user
      case None =>
        if (initializationMode == KMeans.RANDOM) {
          initRandom(data) // sample the centers at random
        } else {
          initKMeansParallel(data, distanceMeasureInstance) // k-means|| initialization
        }
    }

    val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
    logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")

    var converged = false
    var cost = 0.0
    var iteration = 0

    val iterationStartTime = System.nanoTime()
    instr.foreach(_.logNumFeatures(centers.head.vector.size))

    // Execute iterations of Lloyd's algorithm until converged
    while (iteration < maxIterations && !converged) {
      val costAccum = sc.doubleAccumulator  // accumulator for the total cost of this iteration
      val bcCenters = sc.broadcast(centers) // broadcast the current centers to all executors

      // Find the new centers
      val collected = data.mapPartitions { points =>
        val thisCenters = bcCenters.value
        val dims = thisCenters.head.vector.size

        // one running-sum vector per center, each with the dimensionality of the data
        val sums = Array.fill(thisCenters.length)(Vectors.zeros(dims))
        // number of points assigned to each center
        val counts = Array.fill(thisCenters.length)(0L)

        points.foreach { point =>
          // compute the distance from the point to every center and keep the closest one
          val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
          costAccum.add(cost)
          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
          counts(bestCenter) += 1 // count how many points fall into each cluster
        }

        // sums(j) now holds the sum of all points in this partition assigned to center j
        counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
      }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
        axpy(1.0, sum2, sum1) // merge the partial sums from different partitions: sum1 += sum2
        (sum1, count1 + count2)
      }.collectAsMap()

      if (iteration == 0) {
        instr.foreach(_.logNumExamples(collected.values.map(_._2).sum))
      }

      val newCenters = collected.mapValues { case (sum, count) =>
        distanceMeasureInstance.centroid(sum, count)
      }

      bcCenters.destroy() // the centers have been recomputed, so the broadcast is no longer needed

      // Update the cluster centers and costs
      converged = true
      newCenters.foreach { case (j, newCenter) =>
        // if any center moved farther than epsilon from its previous position,
        // the algorithm has not converged yet
        if (converged &&
            !distanceMeasureInstance.isCenterConverged(centers(j), newCenter, epsilon)) {
          converged = false
        }
        centers(j) = newCenter
      }

      cost = costAccum.value
      iteration += 1
    }

    val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
    logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")

    if (iteration == maxIterations) {
      logInfo(s"KMeans reached the max number of iterations: $maxIterations.")
    } else {
      logInfo(s"KMeans converged in $iteration iterations.")
    }

    logInfo(s"The cost is $cost.")

    new KMeansModel(centers.map(_.vector), distanceMeasure, cost, iteration)
  }
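To make the loop above easier to follow, here is a minimal single-machine sketch of one Lloyd iteration on plain Scala arrays. It is my own illustration rather than Spark code, and it hard-codes Euclidean distance: assign each point to its nearest center, accumulate per-cluster sums and counts, recompute the centroids, and declare convergence when no center moved by more than epsilon.

// Sketch only: one Lloyd iteration over in-memory points (Euclidean distance).
def squaredDistance(a: Array[Double], b: Array[Double]): Double =
  a.indices.map { i => val d = a(i) - b(i); d * d }.sum

def lloydStep(points: Seq[Array[Double]],
              centers: Array[Array[Double]],
              epsilon: Double): (Array[Array[Double]], Boolean, Double) = {
  val dims   = centers.head.length
  val sums   = Array.fill(centers.length)(new Array[Double](dims)) // per-cluster running sums
  val counts = new Array[Long](centers.length)                     // per-cluster point counts
  var cost   = 0.0

  // assignment step: find the closest center for every point
  points.foreach { p =>
    val (bestDist, best) = centers.zipWithIndex
      .map { case (c, j) => (squaredDistance(p, c), j) }
      .minBy(_._1)
    cost += bestDist
    for (i <- 0 until dims) sums(best)(i) += p(i)
    counts(best) += 1
  }

  // update step: new centroid = sum / count (keep the old center for empty clusters)
  val newCenters = centers.indices.map { j =>
    if (counts(j) == 0) centers(j) else sums(j).map(_ / counts(j))
  }.toArray

  // convergence check: has every center moved by no more than epsilon?
  val converged = centers.indices.forall { j =>
    math.sqrt(squaredDistance(centers(j), newCenters(j))) <= epsilon
  }
  (newCenters, converged, cost)
}

The distributed version in runAlgorithm performs exactly these steps, except that the assignment and partial (sum, count) aggregation happen per partition inside mapPartitions, the partials are merged with reduceByKey, and only the tiny per-center results are collected back to the driver.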