import breeze.linalg.DenseVector
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import scala.collection.mutable.ArrayBuffer
/**
* @author XiaoTangBao
* @date 2019/3/6 10:13
* @version 1.0
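* The primal (original) form of the perceptron learning algorithm (PLA). For linearly separable data sets,
* the primal-form perceptron is guaranteed to converge: after a finite number of iterations it finds a
* separating hyperplane that classifies every sample in the data set correctly. Here it is applied to the
* iris data set to separate Iris-setosa (label +1) from the other species (label -1).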
*/
object PLA {

  /** Input file used when no path is passed as the first command-line argument. */
  private val DefaultDataPath = "G:\\mldata\\iris.txt"

  /**
   * Upper bound on full passes over the data. The primal perceptron only converges on linearly
   * separable data; without this cap a non-separable data set would make training loop forever.
   */
  private val MaxEpochs = 10000

  /**
   * Loads the iris samples, trains a perceptron separating "Iris-setosa" (+1) from all other
   * classes (-1), and prints every correction followed by the learned weights and bias.
   *
   * @param args optional; `args(0)` overrides the input path (defaults to [[DefaultDataPath]])
   */
  def main(args: Array[String]): Unit = {
    // Silence most of Spark's logging
    Logger.getLogger("org.apache.spark").setLevel(Level.ERROR)
    // Data source: https://pan.baidu.com/s/17dK9fdGHzGY1SfI-s1pt6w
    val path = args.headOption.getOrElse(DefaultDataPath)
    val sparkSession = SparkSession.builder().master("local[4]").appName("PLA").getOrCreate()
    // All rows are collected to the driver, so the session can be released right after loading
    val samples =
      try loadSamples(sparkSession, path)
      finally sparkSession.stop()

    // Samples are 4-dimensional: initial hyperplane (w, b) and learning rate eta
    val (w, b, converged) = train(samples, DenseVector(3.0, 0.8, 1.8, 2.4), 20.0, 0.2, MaxEpochs)

    if (converged) println("训练结束")
    else println(s"训练未收敛: 已达到最大迭代轮数 $MaxEpochs")
    println(w)
    println(b)
  }

  /**
   * Reads the '|'-separated data file (4 numeric features followed by the class name) and labels
   * each row +1 when the class is "Iris-setosa" and -1 otherwise. Blank lines are skipped.
   */
  private def loadSamples(sparkSession: SparkSession, path: String): IndexedSeq[LabeledPoint] =
    sparkSession.sparkContext
      .textFile(path)
      .filter(line => line.trim.nonEmpty)
      .map(line => line.split('|'))
      .collect()
      .map { arr =>
        val label = if (arr(4) == "Iris-setosa") 1.0 else -1.0
        LabeledPoint(label, Vectors.dense(arr(0).toDouble, arr(1).toDouble, arr(2).toDouble, arr(3).toDouble))
      }
      .toIndexedSeq

  /**
   * Runs the primal-form perceptron, printing the 1-based index of every sample it corrects.
   *
   * @param samples   labelled samples (labels are +1 / -1)
   * @param initW     initial weight vector
   * @param initB     initial bias
   * @param eta       learning rate
   * @param maxEpochs maximum number of full passes over `samples`
   * @return the final weights, the final bias, and whether every sample ended up correctly classified
   */
  private def train(
      samples: IndexedSeq[LabeledPoint],
      initW: DenseVector[Double],
      initB: Double,
      eta: Double,
      maxEpochs: Int): (DenseVector[Double], Double, Boolean) = {
    var w = initW
    var b = initB
    var epochs = 0
    var converged = false
    while (!converged && epochs < maxEpochs) {
      for ((xi, i) <- samples.zipWithIndex) {
        val x = DenseVector(xi.features.toArray)
        // Keep correcting the hyperplane on this point until it is classified correctly. Each update
        // raises y*(w.x + b) by eta*(|x|^2 + 1) > 0 (since y*y == 1), so this inner loop terminates.
        while (!judge(w, b, xi)) {
          println("当前纠正:X" + (i + 1))
          w = eta * xi.label * x + w
          b = b + eta * xi.label
        }
      }
      epochs += 1
      // Done once a full pass leaves every sample correctly classified
      converged = samples.forall(xi => judge(w, b, xi))
    }
    (w, b, converged)
  }

  /**
   * Checks whether `xi` is correctly classified by the hyperplane w·x + b = 0, i.e. whether
   * y * (w·x + b) > 0 for its label y. Points lying exactly on the hyperplane count as misclassified.
   *
   * @param w  weight vector
   * @param b  bias
   * @param xi labelled sample (label expected to be +1 or -1)
   * @return true when the sample is on the side of the hyperplane its label indicates
   */
  def judge(w: DenseVector[Double], b: Double, xi: LabeledPoint): Boolean = {
    // `dot` is an alphanumeric infix method, which has the lowest operator precedence in Scala,
    // so (w dot x) must be parenthesized explicitly.
    val margin = xi.label * ((w dot DenseVector(xi.features.toArray)) + b)
    // Written as !(margin <= 0) rather than margin > 0 to keep treating a NaN margin as "correct";
    // counting NaN as a mistake would make the correction loop in train spin forever.
    !(margin <= 0)
  }
}
--------------------------------------------------------------------
当前纠正:X51
当前纠正:X51
当前纠正:X51
当前纠正:X1
当前纠正:X52
当前纠正:X1
当前纠正:X58
当前纠正:X1
当前纠正:X58
当前纠正:X1
当前纠正:X58
当前纠正:X1
当前纠正:X99
当前纠正:X1
训练结束
DenseVector(-1.7200000000000006, -0.1399999999999999, -3.760000000000001, 0.40000000000000024)
19.400000000000002
算法小白的第一次尝试---PLA(感知机算法)实现
最新推荐文章于 2024-01-24 17:30:31 发布