The usage of DecisionTree.trainClassifier is described in the official Spark MLlib API documentation:
def trainClassifier(input: RDD[LabeledPoint], numClasses: Int, categoricalFeaturesInfo: Map[Int, Int], impurity: String, maxDepth: Int, maxBins: Int): DecisionTreeModel
Method to train a decision tree model for binary or multiclass classification.
input Training dataset: RDD of org.apache.spark.mllib.regression.LabeledPoint. Labels should take values {0, 1, ..., numClasses-1}.
numClasses number of classes for classification.
categoricalFeaturesInfo Map storing arity of categorical features. E.g., an entry (n -> k) indicates that feature n is categorical with k categories indexed from 0: {0, 1, ..., k-1}.
impurity Criterion used for information gain calculation. Supported values: "gini" (recommended) or "entropy".
maxDepth Maximum depth of the tree. E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. (suggested value: 5)
maxBins maximum number of bins used for splitting features (suggested value: 32)
returns DecisionTreeModel that can be used for prediction
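To make the impurity parameter more concrete, here is a minimal Python sketch of the two criteria and of the information gain of one binary split. It is not Spark's implementation; the function names are hypothetical, and a base-2 logarithm is assumed for entropy.

```python
import math
from collections import Counter


def gini(labels):
    """Gini impurity: 1 - sum(p_i^2) over the class proportions p_i."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def entropy(labels):
    """Entropy: -sum(p_i * log2(p_i)) over the class proportions p_i."""
    n = len(labels)
    if n == 0:
        return 0.0
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(parent, left, right, impurity=gini):
    """Impurity of the parent minus the size-weighted impurity of the children."""
    n = len(parent)
    weighted = (len(left) / n) * impurity(left) + (len(right) / n) * impurity(right)
    return impurity(parent) - weighted


# A perfectly mixed parent that is split into two pure children.
parent = [0, 0, 1, 1]
print(information_gain(parent, [0, 0], [1, 1], impurity=gini))
print(information_gain(parent, [0, 0], [1, 1], impurity=entropy))
```

A split that separates the classes completely removes all of the parent's impurity, so the gain equals the parent's impurity under either criterion.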
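The kyphosis dataset
Meaning of each column in the kyphosis dataset:
The data come from children who underwent corrective spinal surgery. The dataset has 4 columns and 81 rows (81 cases).
1. kyphosis: a factor indicating whether kyphosis (a hunched spine) was still present after the operation; its values are absent and present
2. Age: the age of the child, in months
3. Number: the number of vertebrae involved in the operation
4. Start: the number of the first vertebra operated on, counting from the top of the spine downward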
absent 158 3 14
present 128 4 5
absent 2 5 1
absent 1 4 15
absent 1 2 16
absent 61 2 17
absent 37 3 16
absent 113 2 16
present 59 6 12
present 82 5 14
absent 148 3 16
absent 18 5 2
absent 1 4 12
absent 168 3 18
absent 1 3 16
absent 78 6 15
absent 175 5 13
absent 80 5 16
absent 27 4 9
absent 22 2 16
present 105 6 5
present 96 3 12
absent 131 2 3
present 15 7 2
absent 9 5 13
absent 8 3 6
absent 100 3 14
absent 4 3 16
absent 151 2 16
absent 31 3 16
absent 125 2 11
absent 130 5 13
absent 112 3 16
absent 140 5 11
absent 93 3 16
absent 1 3 9
present 52 5 6
absent 20 6 9
present 91 5 12
present 73 5 1
absent 35 3 13
absent 143 9 3
absent 61 4 1
absent 97 3 16
present 139 3 10
absent 136 4 15
absent 131 5 13
present 121 3 3
absent 177 2 14
absent 68 5 10
absent 9 2 17
present 139 10 6
absent 2 2 17
absent 140 4 15
absent 72 5 15
absent 2 3 13
present 120 5 8
absent 51 7 9
absent 102 3 13
present 130 4 1
present 114 7 8
absent 81 4 1
absent 118 3 16
absent 118 4 16
absent 17 4 10
absent 195 2 17
absent 159 4 13
absent 18 4 11
absent 15 5 16
absent 158 5 14
absent 127 4 12
absent 87 4 16
absent 206 4 10
absent 11 3 15
absent 178 4 15
present 157 3 13
absent 26 7 13
absent 120 2 13
present 42 7 6
absent 36 4 13
By default the columns of this dataset are separated by several spaces, so I wrote a Python script that collapses them into a single space.
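import re

# Collapse runs of whitespace into a single space, one record per line.
with open("e:/MyProject/SparkDiscover/data/kyphosis.data", "r") as f:
    # Iterate over the file directly so that the first record is not skipped
    # and no empty line is printed at the end of the file.
    for line in f:
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        out = re.sub(r"\s{2,}", " ", line)
        print(out)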
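The same cleanup can be wrapped in a small function so that it can be checked without touching the file system. This is only a sketch: normalize_line and normalize_lines are hypothetical names, and the irregular spacing in the sample rows is made up for the demonstration.

```python
import re


def normalize_line(line):
    """Collapse every run of whitespace into one space and trim both ends."""
    return re.sub(r"\s+", " ", line).strip()


def normalize_lines(lines):
    """Normalize all non-blank lines and drop the blank ones."""
    return [normalize_line(line) for line in lines if line.strip()]


# Two rows of the dataset, written with irregular spacing for illustration.
sample = ["absent   158    3   14\n", "\n", "present  128  4    5\n"]
for row in normalize_lines(sample):
    print(row)
```

Because the function takes any iterable of strings, it works equally well on an open file object.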
Test code
package classify

import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.{SparkConf, SparkContext}

object DsTree {

  // Parse one preprocessed record such as "absent 158 3 14" into a LabeledPoint.
  def parseLine(line: String): LabeledPoint = {
    val parts = line.split(" ")
    // Features: Age, Number, Start
    val vd: Vector = Vectors.dense(parts(1).toDouble, parts(2).toDouble, parts(3).toDouble)
    // Label: absent -> 0, present -> 1
    val target = parts(0) match {
      case "absent" => 0.0
      case "present" => 1.0
    }
    LabeledPoint(target, vd)
  }

  // args(0): Spark master URL, args(1): path to the preprocessed data file
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster(args(0)).setAppName("DsTree")
    val sc = new SparkContext(conf)

    val data = sc.textFile(args(1)).map(parseLine(_))

    // 70% of the records for training, 30% for testing
    val splits = data.randomSplit(Array(0.7, 0.3), seed = 11L)
    val trainData = splits(0)
    val testData = splits(1)

    val numClasses = 2 // number of classes
    val categoricalFeaturesInfo = Map[Int, Int]() // empty map: all features are treated as continuous
    val impurity = "entropy" // impurity criterion: "entropy" or "gini"
    val maxDepth = 5 // maximum depth of the tree
    val maxBins = 3 // maximum number of bins used for splitting features (the API docs suggest 32)

    val model = DecisionTree.trainClassifier(trainData, numClasses, categoricalFeaturesInfo,
      impurity, maxDepth, maxBins)

    val predictionAndLabel = testData.map(p => (model.predict(p.features), p.label))
    // The test set is small, so collect it to the driver before printing.
    predictionAndLabel.collect().foreach(println)

    val metrics = new MulticlassMetrics(predictionAndLabel)
    // Note: metrics.precision was deprecated in Spark 2.0 in favor of metrics.accuracy.
    val precision = metrics.precision
    println("Precision = " + precision)

    sc.stop()
  }
}
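To show what kind of number the final line reports, here is a small Python sketch that evaluates a list of (prediction, label) pairs. As far as I know, the argument-free precision of MulticlassMetrics in older MLlib releases is the fraction of correctly classified test points, which is the overall accuracy; treat that as an assumption. The function names and the sample pairs are hypothetical and are not output from the program above.

```python
def accuracy(pairs):
    """Fraction of pairs whose prediction equals the label."""
    pairs = list(pairs)
    if not pairs:
        return 0.0
    return sum(1 for pred, label in pairs if pred == label) / len(pairs)


def precision_for(pairs, cls):
    """Of all points predicted as cls, the fraction whose label really is cls."""
    predicted = [(pred, label) for pred, label in pairs if pred == cls]
    if not predicted:
        return 0.0
    return sum(1 for _, label in predicted if label == cls) / len(predicted)


# Made-up (prediction, label) pairs: 0.0 = absent, 1.0 = present.
pairs = [(0.0, 0.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
print("accuracy =", accuracy(pairs))
print("precision for class 1.0 =", precision_for(pairs, 1.0))
```

With an imbalanced dataset such as this one, where most cases are absent, the per-class value for present is usually more informative than the overall figure.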