// Source attribution: https://www.wandouip.com/t5i28437/ — original work; please credit the source when reposting.
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ArrayBuffer
/**
 * AdaBoost with one-dimensional decision stumps, reproducing example 8.1
 * from "Statistical Learning Methods" (Li Hang).
 *
 * Each weak learner G(x) is a threshold classifier on the single feature.
 * After every round the sample weights are re-normalized and the loop stops
 * once the combined classifier sign[f(x)] separates the training set
 * perfectly (or the round cap is hit).
 *
 * @author XiaoTangBao
 * @date 2019/3/9 10:03
 * @version 1.0
 */
object AdaBoost {
  // Hard cap on boosting rounds so a non-separable data set cannot spin forever.
  private val MaxRounds = 100

  def main(args: Array[String]): Unit = {
    // Silence Spark's verbose INFO logging.
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    // Local session kept for parity with the original example; the algorithm
    // below actually runs entirely on the driver.
    val spark = SparkSession.builder().master("local[4]").appName("A2").getOrCreate()
    try {
      // Training data: (feature value, label in {-1, +1}).
      val arr = Array((0,1),(1,1),(2,1),(3,-1),(4,-1),(5,-1),(6,1),(7,1),(8,1),(9,-1))
      // Initial weights: uniform distribution over the samples.
      val weights = ArrayBuffer.fill(arr.length)(1.0 / arr.length)
      // Each labeled point carries (feature, current weight); the weight rides
      // in the feature vector so findGmX can read it per sample.
      val trainData = ArrayBuffer.tabulate(arr.length) { i =>
        LabeledPoint(arr(i)._2, Vectors.dense(arr(i)._1, weights(i)))
      }
      // Candidate split points: midpoints between consecutive integer features.
      val cutpoint = arr.map(_._1 + 0.5)

      // Coefficient a_m of each round's weak classifier.
      val amArr = ArrayBuffer[Double]()
      // Best weak classifier per round: (cut point, error rate, G(x)).
      val GmXArr = ArrayBuffer[(Double, Double, (Double, Double) => Double)]()

      var converged = false
      var round = 0
      while (!converged && round < MaxRounds) {
        round += 1
        // Weak classifier with the lowest weighted error under current weights.
        val bestGx = findGmX(trainData.toArray, cutpoint)
        GmXArr.append(bestGx)
        val em = bestGx._2
        if (em == 0.0) {
          // A zero-error stump classifies everything by itself; computing
          // a_m = 0.5*ln((1-0)/0) would be infinite and poison the weights
          // with NaN, so assign a fixed positive coefficient and stop.
          amArr.append(1.0)
          converged = true
        } else {
          // a_m = 0.5 * ln((1 - e_m) / e_m), rounded to 4 decimals as in the book.
          val am = (0.5 * math.log((1 - em) / em)).formatted("%.4f").toDouble
          amArr.append(am)
          println("am:" + am)
          println("em:" + em)
          println("cutpoint:" + bestGx._1)
          // Normalization factor Z_m = sum_i w_i * exp(-a_m * y_i * G_m(x_i)).
          var zm = 0.0
          for (i <- trainData.indices) {
            zm += weights(i) * math.exp(-am * trainData(i).label * bestGx._3(bestGx._1, trainData(i).features(0)))
          }
          // Weight update: w_i <- (w_i / Z_m) * exp(-a_m * y_i * G_m(x_i)).
          for (i <- weights.indices) {
            weights(i) = (weights(i) / zm) * math.exp(-am * trainData(i).label * bestGx._3(bestGx._1, trainData(i).features(0)))
          }
          // Refresh the weight stored inside each training point.
          for (i <- trainData.indices) {
            trainData(i) = LabeledPoint(trainData(i).label, Vectors.dense(trainData(i).features(0), weights(i)))
          }
          // Converged when sign[f(x)] = sign[sum_m a_m * G_m(x)] is correct on
          // every sample (forall short-circuits on the first misclassification,
          // matching the original early-exit loop).
          converged = trainData.forall { td =>
            val fx = amArr.indices.map(i => amArr(i) * GmXArr(i)._3(GmXArr(i)._1, td.features(0))).sum
            math.signum(fx) == td.label
          }
        }
      }

      // Final model: sign[ sum_m a_m * GmX(cut point) ].
      val terms = amArr.indices.map(i => s"${amArr(i)}*GmX(${GmXArr(i)._1})")
      println("AdaBoost_Model为:" + "sign [" + terms.mkString(" + ") + " ]")
    } finally {
      // Always release the Spark session (the original leaked it).
      spark.stop()
    }
  }

  /**
   * Finds the best decision stump for the current sample weights.
   *
   * Each training point encodes (feature, weight) in its feature vector:
   * features(0) is the x value, features(1) the current AdaBoost weight.
   * For every candidate cut point both stump orientations are scored by
   * weighted misclassification error and the better one is kept; the overall
   * winner is the stump with the minimal error (first one on ties, matching
   * the original stable sort).
   *
   * @param trainData weighted samples, label in {-1.0, +1.0}
   * @param cutpoint  candidate thresholds
   * @return (cut point, weighted error rate, G) where G: (cutpoint, x) => label
   */
  def findGmX(trainData: Array[LabeledPoint], cutpoint: Array[Double]): (Double, Double, (Double, Double) => Double) = {
    // Two orientations per cut point: predict +1 above the threshold, or -1.
    val Gx_1 = (cp: Double, x: Double) => if (x > cp) 1.0 else -1.0
    val Gx_2 = (cp: Double, x: Double) => if (x > cp) -1.0 else 1.0
    // Score both orientations at every cut point; keep the better of the two.
    val candidates = cutpoint.map { cp =>
      // Weighted error = sum of weights of misclassified samples.
      val e1 = trainData.iterator.collect { case td if Gx_1(cp, td.features(0)) != td.label => td.features(1) }.sum
      val e2 = trainData.iterator.collect { case td if Gx_2(cp, td.features(0)) != td.label => td.features(1) }.sum
      if (e1 >= e2) (cp, e2, Gx_2) else (cp, e1, Gx_1)
    }
    // Best weak learner overall: minimal weighted error.
    candidates.minBy(_._2)
  }
}
// ----------------------------- Result -----------------------------------------------------------
// am:0.4236
// em:0.30000000000000004
// cutpoint:2.5
// am:0.6496
// em:0.21429619932719202
// cutpoint:8.5
// am:0.752
// em:0.1818313875438102
// cutpoint:5.5
// AdaBoost_Model为:sign [0.4236*GmX(2.5) + 0.6496*GmX(8.5) + 0.752*GmX(5.5) ]