import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.HingeGradient
import org.apache.spark.mllib.optimization.SquaredL2Updater
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SparkSession
/**
 * Trains, persists, reloads and evaluates a linear SVM (MLlib `SVMWithSGD`)
 * that separates Iris-setosa (label 0) from Iris-versicolor (label 1).
 *
 * Usage: `SVM [dataPath [modelPath]]`. Both arguments are optional and fall
 * back to the paths that were originally hard-coded in this program.
 *
 * @author XiaoTangBao
 * @date 2019/3/6 21:20
 * @version 1.0
 */
object SVM {

  /** Input file used when no data path is passed on the command line. */
  private val DefaultDataPath = "G:\\mldata\\iris.txt"

  /** Model directory used when no model path is passed on the command line. */
  private val DefaultModelPath = "C:\\users\\Java_Man_China\\desktop\\model1"

  /**
   * Parses one `|`-separated line (four numeric features followed by the species name).
   *
   * @param line raw line from the input file
   * @return the labeled point for Iris-setosa (0.0) or Iris-versicolor (1.0) rows;
   *         `None` for blank lines, rows with fewer than five fields and other species
   * @throws NumberFormatException if a feature of a kept row is not a valid number
   */
  private def parseLabeledPoint(line: String): Option[LabeledPoint] =
    line.trim.split('|') match {
      case Array(f1, f2, f3, f4, species, _*) =>
        // Spark's SVM expects labels 0 and 1 (internally they are mapped to -1 and 1).
        val label = species match {
          case "Iris-setosa"     => Some(0.0)
          case "Iris-versicolor" => Some(1.0)
          case _                 => None
        }
        label.map { l =>
          LabeledPoint(l, Vectors.dense(f1.toDouble, f2.toDouble, f3.toDouble, f4.toDouble))
        }
      case _ => None
    }

  /**
   * Program entry point.
   *
   * @param args optional overrides: `args(0)` = input data path, `args(1)` = model directory
   */
  def main(args: Array[String]): Unit = {
    // Suppress Spark's own log output below ERROR.
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)

    val dataPath = args.lift(0).getOrElse(DefaultDataPath)
    val modelPath = args.lift(1).getOrElse(DefaultModelPath)

    val sparkSession = SparkSession.builder().master("local[4]").appName("SVM").getOrCreate()
    try {
      val sc = sparkSession.sparkContext

      // Load the data source and keep only the two species used for binary classification.
      val pddata = sc.textFile(dataPath).flatMap(line => parseLabeledPoint(line).toList)

      // Simple hold-out validation: 80% training, 20% test.
      val splitdata = pddata.randomSplit(Array(0.8, 0.2))
      val traindata = splitdata(0).cache()
      val testdata = splitdata(1)

      // Model hyper-parameters.
      val trainer = new SVMWithSGD()
      trainer.optimizer
        .setNumIterations(1000)
        .setRegParam(0.1)
        .setStepSize(0.3)
        .setMiniBatchFraction(0.5)
        .setGradient(new HingeGradient())
        .setUpdater(new SquaredL2Updater)
      val svmModel = trainer.run(traindata)

      // Round-trip the model through storage to verify that it can be persisted.
      // NOTE(review): save is expected to fail when modelPath already exists — confirm and
      // remove the directory between runs if so.
      svmModel.save(sc, modelPath)
      val sameModel = SVMModel.load(sc, modelPath)

      // With a threshold set, predict returns hard 0/1 labels, which makes the ROC curve
      // degenerate. Remember the threshold, then clear it so predict returns raw margins.
      val decisionThreshold = sameModel.getThreshold.getOrElse(0.0)
      sameModel.clearThreshold()

      // Score each test point once; cached because several actions below reuse it.
      val scoreAndLabel = testdata
        .map(point => (sameModel.predict(point.features), point.label))
        .cache()

      // Re-apply the decision threshold to obtain the predicted class labels.
      val predictionAndLabel = scoreAndLabel.map { case (score, label) =>
        (if (score > decisionThreshold) 1.0 else 0.0, label)
      }
      // Collect first so the pairs are printed on the driver.
      predictionAndLabel.collect().foreach(println)

      // Binary-classification evaluator: area under ROC from the raw scores.
      val metrics = new BinaryClassificationMetrics(scoreAndLabel)
      val auroc = metrics.areaUnderROC()
      println(auroc)

      // Multiclass evaluator: accuracy from the thresholded predictions.
      val metric = new MulticlassMetrics(predictionAndLabel)
      val ac = metric.accuracy
      println(ac)
    } finally {
      // Always release the local Spark resources, even when a step above fails.
      sparkSession.stop()
    }
  }
}
// An algorithm beginner's first attempt --- SVM implementation
// Latest recommended article published on 2024-10-10 17:22:35