Spark MLlib 入门学习笔记 - 决策树

本文是关于Spark MLLib中决策树的学习笔记,以kyphosis数据集为例,介绍了数据集的背景及包含的特征,包括 kyphosis、Age、Number 和 Start。通过测试代码展示了如何运用Spark进行决策树模型训练和应用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在官方API文档可以查到用法。

def trainClassifier(input: RDD[LabeledPoint], numClasses: Int, categoricalFeaturesInfo: Map[Int, Int], impurity: String, maxDepth: Int, maxBins: Int): DecisionTreeModel
Method to train a decision tree model for binary or multiclass classification.
input Training dataset: RDD of org.apache.spark.mllib.regression.LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses number of classes for classification.
categoricalFeaturesInfo Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
maxBins maximum number of bins used for splitting features (suggested value: 32)
returns DecisionTreeModel that can be used for prediction
kyphosis 数据集

kyphosis数据集的各列含义:
数据集是从儿童接受外科脊柱矫正手术中来的,数据集有4列、81行(81个病例)。
1、kyphosis:采取手术后依然出现脊柱后凸(驼背)的因子
2、Age:单位是“月”
3、Number:代表进行手术的脊柱椎骨的数目
4、Start:在脊柱上从上往下数、参与手术的第一节椎骨所在的序号

absent 158 3 14
present 128 4 5
absent 2 5 1
absent 1 4 15
absent 1 2 16
absent 61 2 17
absent 37 3 16
absent 113 2 16
present 59 6 12
present 82 5 14
absent 148 3 16
absent 18 5 2
absent 1 4 12
absent 168 3 18
absent 1 3 16
absent 78 6 15
absent 175 5 13
absent 80 5 16
absent 27 4 9
absent 22 2 16
present 105 6 5
present 96 3 12
absent 131 2 3
present 15 7 2
absent 9 5 13
absent 8 3 6
absent 100 3 14
absent 4 3 16
absent 151 2 16
absent 31 3 16
absent 125 2 11
absent 130 5 13
absent 112 3 16
absent 140 5 11
absent 93 3 16
absent 1 3 9
present 52 5 6
absent 20 6 9
present 91 5 12
present 73 5 1
absent 35 3 13
absent 143 9 3
absent 61 4 1
absent 97 3 16
present 139 3 10
absent 136 4 15
absent 131 5 13
present 121 3 3
absent 177 2 14
absent 68 5 10
absent 9 2 17
present 139 10 6
absent 2 2 17
absent 140 4 15
absent 72 5 15
absent 2 3 13
present 120 5 8
absent 51 7 9
absent 102 3 13
present 130 4 1
present 114 7 8
absent 81 4 1
absent 118 3 16
absent 118 4 16
absent 17 4 10
absent 195 2 17
absent 159 4 13
absent 18 4 11
absent 15 5 16
absent 158 5 14
absent 127 4 12
absent 87 4 16
absent 206 4 10
absent 11 3 15
absent 178 4 15
present 157 3 13
absent 26 7 13
absent 120 2 13
present 42 7 6
absent 36 4 13
这个数据集,缺省是用多个空格隔开的,所以写了一个python脚本将其处理成一个空格隔开。

import re

f = open("e:/MyProject/SparkDiscover/data/kyphosis.data","r")

line = f.readline()
while line:
    line = f.readline()
    line = line.strip("\n")
    out = re.sub(r"\s{2,}", " ", line)
    print(out)
 

测试代码

package classify

import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint

object DsTree {
  def parseLine(line: String): LabeledPoint = {
    val parts = line.split(" ")
    val vd: Vector = Vectors.dense(parts(1).toInt, parts(2).toInt, parts(3).toInt)
    var target = 0
    parts(0) match {
      case "absent" => target = 0;
      case "present" => target = 1;
    }
    return LabeledPoint(target, vd)
  }

  def main(args: Array[String]) {
    val conf = new SparkConf().setMaster(args(0)).setAppName("Iris")
    val sc = new SparkContext(conf)
    val data = sc.textFile(args(1)).map(parseLine(_))

    val splits = data.randomSplit(Array(0.7, 0.3), seed = 11L)
    val trainData = splits(0)
    val testData = splits(1)

    val numClasses = 2 //分类数量
    val categoricalFeaturesInfo = Map[Int, Int]() //输入格式
    val impurity = "entropy" //信息增益计算方式 gini
    val maxDepth = 5 //树的高度
    val maxBins = 3 //分裂数据集
    val model = DecisionTree.trainClassifier(trainData, numClasses, categoricalFeaturesInfo,
      impurity, maxDepth, maxBins)

    val predictionAndLabel = testData.map(p => (model.predict(p.features), p.label))
    predictionAndLabel.foreach(println)

    val metrics = new MulticlassMetrics(predictionAndLabel)
    val precision = metrics.precision
    println("Precision = " + precision)
  }
}




评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值