// Reference:
// https://github.com/CraigCovey/spark-examples/blob/f8182a6736fd5293dfa03b023eb1423363ba6041/spark-1_6/scala/clustering/kmeans/kmeans_clustering_main.scala
package com.xx.Kmeans_sample
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.{StandardScaler, VectorAssembler}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
/**
 * K-means clustering example built on a Spark ML pipeline.
 *
 * The pipeline assembles four numeric columns into a feature vector,
 * standardizes it (zero mean, unit standard deviation) and clusters the
 * result with K-means. The predictions, the clustering cost and the cluster
 * centers are printed to standard output.
 */
object KmeansClusteringMain {

  /** CSV file read when no path is supplied on the command line. */
  private val DefaultInputPath = "D:/spark_data/iris.csv"

  /**
   * Runs the clustering job.
   *
   * @param args optional; `args(0)` is the path of the CSV file to cluster.
   *             When absent, [[DefaultInputPath]] is used.
   * @throws IllegalStateException if the last stage of the fitted pipeline is
   *                               not a `KMeansModel`
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      .setAppName("ReadData")
      // Fall back to a local master only when none was configured, so a master
      // supplied by the launcher (e.g. spark-submit --master) is not overridden.
      .setIfMissing("spark.master", "local")
      .set("spark.sql.warehouse.dir", "file:///C:/Users/username/IdeaProjects/spark_demo/spark-warehouse")
    val sparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // Always release the Spark session, even when loading or fitting fails.
    try {
      val inputPath = args.headOption.getOrElse(DefaultInputPath)

      // The file must have a header row; column types are inferred from the data.
      val data = sparkSession.read
        .format("csv")
        .option("sep", ",")
        .option("inferSchema", "true")
        .option("header", "true")
        .load(inputPath)

      val predictorVariables: Array[String] =
        Array("Sepal_Length", "Sepal_Width", "Petal_Length", "Petal_Width")

      // Stage 1: combine the predictor columns into a single vector column.
      val assembler = new VectorAssembler()
        .setInputCols(predictorVariables)
        .setOutputCol("clusteringFeatures")

      // Stage 2: center and scale every feature.
      val scaler = new StandardScaler()
        .setInputCol("clusteringFeatures")
        .setOutputCol("scaledClusteringFeatures")
        .setWithMean(true)
        .setWithStd(true)

      // Stage 3: K-means on the standardized features.
      val kmeansAlgorithm = new KMeans()
        .setK(10) // number of clusters
        .setSeed(1024) // fixed seed for reproducible initialization
        .setMaxIter(20) // hyperparameter
        .setTol(1.0e-05) // hyperparameter
        .setFeaturesCol("scaledClusteringFeatures")
        .setPredictionCol("columnCategory") // name of the cluster-id output column

      val pipeline = new Pipeline().setStages(Array(assembler, scaler, kmeansAlgorithm))

      // Train the model.
      val pipelineModel = pipeline.fit(data)

      // Apply the model to the same dataframe it was trained on.
      val kmeansPrediction = pipelineModel.transform(data)
      kmeansPrediction.show()

      // K-means is the final pipeline stage; fail loudly if that ever changes
      // instead of surfacing a bare ClassCastException.
      val kmeansModel = pipelineModel.stages.last match {
        case model: KMeansModel => model
        case other =>
          throw new IllegalStateException(
            s"Expected the last pipeline stage to be a KMeansModel but found ${other.getClass.getName}")
      }

      // Evaluate clustering by computing the Within Set Sum of Squared Errors.
      // NOTE(review): computeCost is deprecated since Spark 2.4 and removed in
      // Spark 3.0; on newer versions use kmeansModel.summary.trainingCost or a
      // ClusteringEvaluator instead.
      val cost = kmeansModel.computeCost(kmeansPrediction)
      println(s"Clustering Cost: $cost")

      // Print cluster centers (expressed in the standardized feature space).
      println("Cluster Centers:")
      kmeansModel.clusterCenters.foreach(println)
    } finally {
      sparkSession.stop()
    }
  }
}