import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest

object ForestTest {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("RandomForestTrain").setMaster("local[2]")
    val sc = new SparkContext(conf)

    // Load the data: the first field is skipped, the last field is the class label,
    // and the fields in between become the feature vector.
    val data = sc.textFile("f://rf.csv").map { line =>
      val fields = line.split(",")
      val label = fields(fields.length - 1).toDouble
      val features = fields.slice(1, fields.length - 1).map(_.toDouble)
      LabeledPoint(label, Vectors.dense(features))
    }
    val labels = data.map(_.label)

    // Train the random forest: 9 classes, no categorical features, 20 trees,
    // "auto" feature-subset strategy, entropy impurity, maxDepth = 30, maxBins = 300.
    val model = RandomForest.trainClassifier(data, 9, Map[Int, Int](), 20, "auto", "entropy", 30, 300)

    // Score every training point with the trained forest.
    val predictions = data.map(point => model.predict(point.features))
    predictions.foreach(println)

    // Training-set accuracy: fraction of points whose prediction matches the label.
    // (labels and predictions are both narrow maps of data, so zip keeps them aligned.)
    val acc = labels.zip(predictions)
      .filter { case (label, prediction) => label == prediction }
      .count() / labels.count().toDouble
    println("Random forest accuracy for predicting patients' hospital cost:")
    println(acc)
  }
}
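Note that the accuracy above is measured on the same rows the forest was trained on, so it will be optimistic. A minimal sketch of a fairer check, assuming the `data` RDD parsed above (the 0.7/0.3 split ratio and the seed are arbitrary choices):

// Sketch: hold out 30% of the rows and measure accuracy on data the model has not seen.
val Array(trainData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 42L)
val heldOutModel = RandomForest.trainClassifier(trainData, 9, Map[Int, Int](), 20, "auto", "entropy", 30, 300)
val testAcc = testData
  .map(p => (heldOutModel.predict(p.features), p.label))
  .filter { case (pred, label) => pred == label }
  .count() / testData.count().toDouble
println(s"held-out accuracy: $testAcc")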
rf.csv has the format shown below; the parsing code above skips the first field, uses the last field as the class label, and treats everything in between as numeric features:
1,1,1,30,2201,0,1,20,210105,51,3,5,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3.2,0,3.2,3.2,3.2,3.2,0,1,0,1,1,0,10,4680,1,1,80,0,1,2,2,5,0,1,2101,1,1,2101,3,1,1,1,1,1,1,0,0,0,0,1,5467,3253,300133,300133,13,1,300102,4535,0,0,0,0,75357,8907.53,1,1,1,1,1,1,1,1,1,1,1,0,0,1,1,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,55
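If the trained forest needs to be reused later, MLlib's RandomForestModel can be saved and reloaded; a minimal sketch, assuming the `model` and `sc` from the listing above (the output path is a hypothetical placeholder):

import org.apache.spark.mllib.tree.model.RandomForestModel

// Sketch: persist the trained forest and load it back (the path is a hypothetical placeholder).
model.save(sc, "f:/rfModel")
val reloaded = RandomForestModel.load(sc, "f:/rfModel")
println(reloaded.numTrees)   // quick sanity check that the reloaded ensemble is intact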