Spark MLlib 课堂学习笔记 - 逻辑回归

本文介绍如何使用Spark MLlib中的逻辑回归算法,包括SGD梯度下降法和LBFGS两种实现方式,通过交通事故数据集的实际操作,展示了从数据准备到模型训练及预测的全过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

关于逻辑回归的算法原理Spark官方文档里有说明,另外网上也有中文翻译文档可参考。本笔记是学习MLlib的辑回归API使用时一道练习题记录,通过这道练习,可以掌握基本使用。MLLib提供了两种算法实现,分别是SGD梯度下降法和LBFGS。

1. 数据文件

交通事故的统计文件,四列,accident(去年是否出过事故,1表示出过事故,0表示没有),age(年龄 数值型),vision(视力状况,分类型,1表示好,0表示有问题),drive(驾车教育,分类型,1表示参加过驾车教育,0表示没有)。第1列是因变量,其它3列是特征。这是一个用空格分隔的文本文件,要使用MLLib算法库,首先要读文件并转成LabeledPoint数据类型的RDD。


[plain] view plain copy

  1. 1 17 1 1  

  2. 1 44 0 0  

  3. 1 48 1 0  

  4. 1 55 0 0  

  5. 1 75 1 1  

  6. 0 35 0 1  

  7. 0 42 1 1  

  8. 0 57 0 0  

  9. 0 28 0 1  

  10. 0 20 0 1  

  11. 0 38 1 0  

  12. 0 45 0 1  

  13. 0 47 1 1  

  14. 0 52 0 0  

  15. 0 55 0 1  

  16. 1 68 1 0  

  17. 1 18 1 0  

  18. 1 68 0 0  

  19. 1 48 1 1  

  20. 1 17 0 0  

  21. 1 70 1 1  

  22. 1 72 1 0  

  23. 1 35 0 1  

  24. 1 19 1 0  

  25. 1 62 1 0  

  26. 0 39 1 1  

  27. 0 40 1 1  

  28. 0 55 0 0  

  29. 0 68 0 1  

  30. 0 25 1 0  

  31. 0 17 0 0  

  32. 0 45 0 1  

  33. 0 44 0 1  

  34. 0 67 0 0  

  35. 0 55 0 1  

  36. 1 61 1 0  

  37. 1 19 1 0  

  38. 1 69 0 0  

  39. 1 23 1 1  

  40. 1 19 0 0  

  41. 1 72 1 1  

  42. 1 74 1 0  

  43. 1 31 0 1  

  44. 1 16 1 0  

  45. 1 61 1 0  

2. SGD算法


[plain] view plain copy

  1. package classify  

  2.   

  3. /*  

  4. accident.txt  

  5. accident(去年是否出过事故,1表示出过事故,0表示没有)  

  6. age(年龄 数值型)  

  7. vision(视力状况,分类型,1表示好,0表示有问题)  

  8. drive(驾车教育,分类型,1表示参加过驾车教育,0表示没有)  

  9.  */  

  10. import org.apache.spark.mllib.linalg.{Vector, Vectors}  

  11. import org.apache.spark.mllib.regression.LabeledPoint  

  12. import org.apache.spark.mllib.classification.LogisticRegressionWithSGD  

  13. import org.apache.spark.{SparkConf, SparkContext}  

  14.   

  15. object LogisticSGD {  

  16.   

  17.   def parseLine(line: String): LabeledPoint = {  

  18.     val parts = line.split(" ")  

  19.     val vd: Vector = Vectors.dense(parts(1).toDouble, parts(2).toDouble, parts(3).toDouble)  

  20.     return LabeledPoint(parts(0).toDouble, vd )  

  21.   }  

  22.   

  23.   

  24.   def main(args: Array[String]){  

  25.     val conf = new SparkConf().setMaster(args(0)).setAppName("LogisticSGD")  

  26.     val sc = new SparkContext(conf)  

  27.     val data =  sc.textFile(args(1)).map(parseLine(_))  

  28.   

  29.     val splits = data.randomSplit(Array(0.6, 0.4), seed=11L)  

  30.     val trainData = splits(0)  

  31.     val testData = splits(1)  

  32.   

  33.     val model = LogisticRegressionWithSGD.train(trainData, 50)  

  34.   

  35.     println(model.weights.size)  

  36.     println(model.weights)  

  37.     println(model.weights.toArray.filter(_ != 0).size)  

  38.   

  39.     val predictionAndLabel = testData.map(p => (model.predict(p.features), p.label))  

  40.   

  41.     predictionAndLabel.foreach(println)  

  42.   

  43.   }  

  44. }  

parseLine函数将文本文件的每一行转成一个LabeledPoint数据类型,randomSplit用例把数据集分成训练和测试两部分。val model = LogisticRegressionWithSGD.train(trainData, 50) 执行训练并得到模型,这里的50为迭代次数。val predictionAndLabel = testData.map(p => (model.predict(p.features), p.label))中的model.predict执行预测,testData.map测试数据集的特征值传递给model去预测,并将预测值与原有的label合并形成一个新的map。

3. LBFGS算法


[plain] view plain copy

  1. package classify  

  2.   

  3. import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS  

  4. import org.apache.spark.{SparkConf, SparkContext}  

  5. import org.apache.spark.mllib.linalg.{Vector, Vectors}  

  6. import org.apache.spark.mllib.regression.LabeledPoint  

  7.   

  8. object LogisticLBFGS {  

  9.   

  10.   def parseLine(line: String): LabeledPoint = {  

  11.     val parts = line.split(" ")  

  12.     val vd: Vector = Vectors.dense(parts(1).toDouble, parts(2).toDouble, parts(3).toDouble)  

  13.     return LabeledPoint(parts(0).toDouble, vd )  

  14.   }  

  15.   

  16.   def main(args: Array[String]){  

  17.     val conf = new SparkConf().setMaster(args(0)).setAppName("LogisticLBFGS")  

  18.     val sc = new SparkContext(conf)  

  19.     val data =  sc.textFile(args(1)).map(parseLine(_))  

  20.   

  21.     val splits = data.randomSplit(Array(0.6, 0.4), seed=11L)  

  22.     val trainData = splits(0)  

  23.     val testData = splits(1)  

  24.   

  25.     val model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(trainData)  

  26.   

  27.     println(model.weights.size)  

  28.     println(model.weights)  

  29.     println(model.weights.toArray.filter(_ != 0).size)  

  30.   

  31.     val prediction = testData.map(p => (model.predict(p.features), p.label))  

  32.   

  33.     //println(prediction)  

  34.     prediction.foreach(println)  

  35.   

  36.   }  

  37. }  


val model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(trainData)中的setNumClasses(2)设置分类数。

对于这个列子,LBFGS的效果比SGD的效果好。


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值