spark高级数据分析实战---用决策树预测森林植被

这篇博客分享了使用Spark进行高级数据分析的实战经验,特别是通过决策树模型预测森林植被。博主提到,由于时间限制,之前的推荐系统内容将后续补充,而本次的决策树程序已在较差性能的机器上运行,期待读者在更强大的集群上尝试以获得更好的效果。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这是我写的这本书的第二个程序,这几天一直研究storm,没时间写,第一个推荐系统由于时间我没及时发回头会补充给大家,这个找了时间参照书上写的,希望对大家有帮助。


package mllib.tree
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}

/**
  * Created by 汪本成 on 2016/7/12.
  */
object trainCovtype {

  //开始时间
  var beg = System.currentTimeMillis()

  //屏蔽不必要的日志显示在终端上
  //Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
  //Logger.getLogger("org.apache.eclipse.jetty.server").setLevel(Level.OFF)

  //创建入口对象
  val conf = new SparkConf().setAppName("trainCovtype").setMaster("local")
  val sc= new SparkContext(conf)

  val HDFS_COVDATA_PATH = "hdfs://node1:9000/user/spark/sparkLearning/mllib/covtype.data"
  val rawData = sc.textFile(HDFS_COVDATA_PATH)

  //设置LabeledPoint格式
  val data = rawData.map{
    line =>
      val values = line.split(",").map(_.toDouble)
      // init返回除最后一个值之外的所有值,最后一列是目标
      val FeatureVector = Vectors.dense(values.init)
      //决策树要求(目标变量)label0开始,所以要减一
      val label = values.last - 1
      LabeledPoint(label, FeatureVector)
  }

  //分成训练集(80%),交叉验证集(10%),测试集(10%)
  val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))
  trainData.cache()
  cvData.cache()
  testData.cache()

  //新建决策树
  val numClass = 7  //分类数量
  val categoricalFeaturesInfo = Map[Int, Int]()  //map存储类别(离散)特征及每个类特征对应值(类别)的数量
  val impurity = "gini"  //纯度计算方法,用于信息增益的计算
  val maxDepth = 4  //树的最大高度
  val maxBins = 100   // 用于分裂特征的最大划分数量

  //训练分类决策树模型
  val model = DecisionTree.trainClassifier(trainData, numClass, categoricalFeaturesInfo, impurity, maxDepth, maxBins)

  val metrics = getMetrics(model,cvData)
  //计算精确度(样本比例)
  val precision = metrics.precision
  //计算每个样本的准确度(召回率)
  val recall = (0 until 7).map(     //DecisionTreeModel模型的类别号从0开始
    cat => (metrics.precision(cat), metrics.recall(cat))
  )
  //混淆矩阵
  val confusionMatrix = metrics.confusionMatrix


  //预测训练数据集
  val trainPriorProbabilities = classProbabilities(trainData)
  //预测cv  val cvPriorProbabilities = classProbabilities(cvData)
  //将所有类别在训练集合cv集出现的概率相乘,然后把结果相加,最后得到对准确度评估
  val two_probabilities = trainPriorProbabilities.zip(cvPriorProbabilities).map {  //cv集中的莫个类别的概率结成对,相乘后再相加
    case (trainProd, cvProd) => trainProd * cvProd
  }.sum

  /**决策树的优化**/
  val evaluations1 =
    for (impurities <- Array("gini", "entropy");
         depth <- Array(1, 20);
         bins <- Array(10, 300)
    )yield {

      val model = DecisionTree.trainClassifier(
        trainData,
        numClass,
        categoricalFeaturesInfo,
        impurities,
        depth,
        bins)
      val predictionAndLabels = cvData.map(
        example =>
          (model.predict(example.features), example.label)
      )
      val accuracy = new MulticlassMetrics(predictionAndLabels).precision
      ((impurities, depth, bins), accuracy)
    }
  //按照第二个值(准确度)降序排序
  val result1 = evaluations1.sortBy(_._2).reverse

  //对优化决策树让训练集集合cv数据集进行评估
  val evaluations2 =
    for (impurities <- Array("gini", "entropy");
       depth <- Array(1, 20);
       bins <- Array(10, 30)
    )yield {
      val model =
        DecisionTree.trainClassifier(
          trainData.union(cvData),
          numClass,
          categoricalFeaturesInfo,
          impurities,
          depth,
          bins
        )
      val predictionAndLabels = trainData.union(cvData).map(
        example =>
          (model.predict(example.features), example.label)
      )
      val accuracy = new MulticlassMetrics(predictionAndLabels).precision
      ((impurities, depth, bins), accuracy)
    }
  //按照第二个值(准确度)降序排序
  val result2 = evaluations2.sortBy(_._2).reverse

  //结束时间
  var end = System.currentTimeMillis()

  //耗时时间
  var castTime = end - beg



  def main(args: Array[String]) {

    println("========================================================================================")
    //精确度(样本比例)
    println("精确度: " + precision)
    println("========================================================================================")
    //准确度(召回率)
    println("准确度: ")
    recall.foreach(println)
    println("========================================================================================")
    //cvtrain的数据集结合对准确度的评估
    println("cvtrain的数据集结合对准确度: " + two_probabilities)
    println("========================================================================================")
    //混淆矩阵
    println("混淆矩阵如下: ")
    println(confusionMatrix)
    println("========================================================================================")
    //cvData下的决策树不同条件下的准确度降序排序
    println("cvData下的决策树不同条件下的准确度降序排如下: ")
    result1.foreach(println)
    println("========================================================================================")
    //cvData结合trainData下不同条件下的准确度降序排序
    println("cvData结合trainData下不同条件下的准确度降序排序如下: ")
    result2.foreach(println)
    println("========================================================================================")
    println(" 运行程序耗时: " + castTime/1000 + "s")

  }

  /**
    * 在训练集构建DecisionTreeModel
    *
    * @param model
    * @param data
    * @return
    */
  def getMetrics(model: DecisionTreeModel, data: RDD[LabeledPoint]): MulticlassMetrics = {
    val predictionsAndLabels = data.map(example => (model.predict(example.features), example.label))
    new MulticlassMetrics(predictionsAndLabels)
  }
  /**
    * 按照类别在训练集出现的比例预测类别
    *
    * @param data
    * @return
    */
  def classProbabilities(data: RDD[LabeledPoint]): Array[Double] = {
    //计算数据中每个类别的样本数(类别, 样本数)
    val countsByCategory = data.map(_.label).countByValue()
    //对类别的样本数进行排序并取出样本数
    val counts = countsByCategory.toArray.sortBy(_._1).map(_._2)
    counts.map(_.toDouble / counts.sum)
  }
}


运行结果如下,因为我的机器比较差,时间长点,跑在集群上就不一样了,大家可以试试

========================================================================================
精确度: 0.6980966928106819
========================================================================================
准确度: 
(0.68203722951904,0.6760325934251195)
(0.7190111755099426,0.7894107720433132)
(0.6342031686859273,0.7663288288288288)
(0.4682926829268293,0.35294117647058826)
(0.0,0.0)
(0.7692307692307693,0.029994001199760048)
(0.7007633587786259,0.4443368828654405)
========================================================================================
cv和train的数据集结合对准确度: 0.37763488623859276
========================================================================================
混淆矩阵如下: 
14436.0  6548.0   8.0     0.0   0.0  3.0   359.0  
5615.0   22454.0  319.0   18.0  0.0  5.0   33.0   
0.0      755.0    2722.0  68.0  0.0  7.0   0.0    
0.0      0.0      176.0   96.0  0.0  0.0   0.0    
0.0      899.0    13.0    0.0   0.0  0.0   0.0    
0.0      540.0    1054.0  23.0  0.0  50.0  0.0    
1115.0   33.0     0.0     0.0   0.0  0.0   918.0  
========================================================================================
cvData下的决策树不同条件下的准确度降序排如下: 
((entropy,20,300),0.9146686803851236)
((gini,20,300),0.9035989496627593)
((entropy,20,10),0.8961676420615443)
((gini,20,10),0.8915338012940429)
((gini,1,300),0.6366039095886179)
((gini,1,10),0.6358487651672473)
((entropy,1,300),0.4881665436696586)
((entropy,1,10),0.4881665436696586)
========================================================================================
cvData结合trainData下不同条件下的准确度降序排序如下: 
((entropy,20,30),0.9509307620195527)
((gini,20,30),0.9386290152863074)
((entropy,20,10),0.9344984598901834)
((gini,20,10),0.9305190457058677)
((gini,1,30),0.6342325278845969)
((gini,1,10),0.6340756471330999)
((entropy,1,30),0.4871319520174482)
((entropy,1,10),0.4871319520174482)
========================================================================================
 运行程序耗时: 331s
16/07/18 18:27:11 INFO SparkContext: Invoking stop() from shutdown hook
16/07/18 18:27:12 INFO SparkUI: Stopped Spark web UI at http://192.168.43.1:4040
16/07/18 18:27:12 INFO MapOutputTrackerMasterEndpoint: MapOutputTrackerMasterEndpoint stopped!
16/07/18 18:27:12 INFO MemoryStore: MemoryStore cleared
16/07/18 18:27:12 INFO BlockManager: BlockManager stopped
16/07/18 18:27:12 INFO BlockManagerMaster: BlockManagerMaster stopped
16/07/18 18:27:12 INFO OutputCommitCoordinator$OutputCommitCoordinatorEndpoint: OutputCommitCoordinator stopped!
16/07/18 18:27:12 INFO SparkContext: Successfully stopped SparkContext
16/07/18 18:27:12 INFO ShutdownHookManager: Shutdown hook called
16/07/18 18:27:12 INFO ShutdownHookManager: Deleting directory C:\Users\Administrator\AppData\Local\Temp\spark-d7aaabfd-bd23-4a4d-b3a4-c21fd38eea98
16/07/18 18:27:12 INFO RemoteActorRefProvider$RemotingTerminator: Shutting down remote daemon.


Process finished with exit code 0


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值