随机森林(random forest)和GBDT都是属于集成学习(ensemble learning)的范畴。集成学习下有两个重要的策略Bagging和Boosting。
Bagging算法是这样做的:每个分类器都随机从原样本中做有放回的采样,然后分别在这些采样后的样本上训练分类器,然后再把这些分类器组合起来。简单的多数投票一般就可以。其代表算法是随机森林。Boosting的意思是这样,他通过迭代地训练一系列的分类器,每个分类器采用的样本分布都和上一轮的学习结果有关。其代表算法是AdaBoost, GBDT。
val conf=new SparkConf().setAppName("GBDTExample")
val sc=new SparkContext(conf)
val sqlcontext=new SQLContext(sc)
import sqlcontext.implicits._
val data = MLUtils.loadLibSVMFile(sc,"/tmp/sample_libsvm_data.txt").toDF("label","features")
val splits=data.randomSplit(Array(0.7,0.3))
val (trainData,testData)=(splits(0),splits(1))
val labelIndexer = new StringIndexer()
.setInputCol("lable")
.setOutputCol("indexLable")
.fit(data)
val featureIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(4)
.fit(data)
val gbdt=new GBTClassifier()
.setLabelCol("indexLable")
.setFeaturesCol("indexedFeatures")
.setMaxIter(10)
val lableConvert=new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictionLable")
.setLabels(labelIndexer.labels)
val pipeline=new Pipeline()
.setStages(Array(labelIndexer,featureIndexer,gbdt,lableConvert))
val model=pipeline.fit(trainData)
val predications=model.transform(testData)
predications.select("predictionLable", "label", "features").show(5)
val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]
println("Learned classification GBT model:\n" + gbtModel.toDebugString)
}