算法:
1. 选择k个簇中心,作为聚类中心 。
2. 计算每个样本点到聚类中心的距离,将每个样品点分配到最近的聚类中心,形成k个簇。
3. 计算每个簇的平均值,并将这个平均值作为新的聚类中心。
4. 反复执行2、3步骤,直到旧质心和新质心的差异小于阈值或迭代次数达到要求为止。
实例:
在IDEA运行,如果是spark-shell命令行窗口,
可使用:paste进入粘贴模式,注意spark-shell下代码中不能有tab
import org.apache.log4j.{ Level, Logger }
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.linalg.Vectors
object KMeans {
def main(args: ArrayString]) {
// 1. 构造spark对象
val conf = new SparkConf().setMaster("local").setAppName("KMeans")
val sc = new SparkContext(conf)
// 去除多余的warn信息
// 2. 读取样本数据,LIBSVM格式
val data = sc.textFile("file:///test/kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()
// 3. 新建KMeans模型,并训练
val initMode = "k-means||"
val numClusters = 2
val numIterations = 20
// 等同于:val model = KMeans.train(parsedData,2,20)
val model = new KMeans()
.setInitializationMode(initMode)
.setK(numClusters)
.setMaxIterations(numIterations)
.run(parsedData)
// 打印聚类中心
model.clusterCenters.foreach(println)
// 4. 误差计算
val WSSSE = model.computeCost(parsedData)
println("Within Set Sum of Squard Errors = " + WSSSE)
// 5. 保存模型、加载模型
val ModelPath = "file:///test/model/KMeans"
model.save(sc, ModelPath)
val sameModel = KMeansModel.load(sc, ModelPath)
}
}
如果报错内存不足:
Java.lang.IllegalArgumentException: System memory 468189184 must be at least 4.718592E8. Please use a larger heap size.
加上语句:
conf.set("spark.testing.memory", "2147480000") //数值大于512m即可
结果:
源码分析:
1. KMeans对象 基于随机梯度下降的SVM分类的伴生对象
train方法 我们常用的train函数,通过调用run方法来训练
2. KMeans类
run方法 训练方法,调用runAlgorithm方法来计算中心点
3. runAlgorithm方法 计算聚类中心点(kmeans算法的核心方法)
initRandom 初始化中心点的方法,支持random和kmeans++两种方法
iteration 迭代计算并更新中心点
4. KMeansModel 模型
predict 预测
1. KMeans对象(定义train方法)
伴生对象:同一个文件中对象名和类名一样,
伴生类和伴生对象的特点是可以相互访问被pr