Spark 2 Linear Regression Example (with Parameter Tuning)

This article shows how to build a predictive model with the ElasticNet regression method from the Apache Spark MLlib library, and how to tune it using Train-Validation Split. The example predicts the murder rate of US states and walks through the full flow of feature assembly, model training, and evaluation.


Regularized regression methods (Lasso, Ridge, and ElasticNet) perform well on high-dimensional data and when there is multicollinearity among the dataset's variables.

Mathematically, ElasticNet is defined as a convex combination of the L1 and L2 regularization terms:

    lambda * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2)

By setting alpha appropriately, ElasticNet contains both L1 and L2 regularization as special cases. For example, if a linear regression model is trained with alpha set to 1, it is equivalent to a Lasso model. On the other hand, if alpha is set to 0, the trained model reduces to a ridge regression model.

In Spark's LinearRegression, the two knobs are:

RegParam: lambda >= 0
ElasticNetParam: alpha in [0, 1]
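The penalty above can be sketched in plain Scala (no Spark required). `elasticNetPenalty` is a hypothetical helper for illustration only, following Spark's convention of halving the L2 term:

```scala
// Sketch of the elastic-net penalty as parameterized by Spark MLlib:
// regParam * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2)
def elasticNetPenalty(w: Array[Double], regParam: Double, alpha: Double): Double = {
  val l1 = w.map(math.abs).sum      // ||w||_1
  val l2Sq = w.map(x => x * x).sum  // ||w||_2^2
  regParam * (alpha * l1 + (1 - alpha) / 2.0 * l2Sq)
}

// alpha = 1.0 recovers the Lasso penalty, alpha = 0.0 the ridge penalty
println(elasticNetPenalty(Array(1.0, -2.0), 1.0, 1.0)) // pure L1: 3.0
println(elasticNetPenalty(Array(1.0, -2.0), 1.0, 0.0)) // pure L2: 5.0 / 2 = 2.5
```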


Import packages

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._
 
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression

Import sample data

 

// Population: state population,
// Income: per-capita income,
// Illiteracy: illiteracy rate,
// LifeExp: life expectancy,
// Murder: murder rate,
// HSGrad: percentage of high-school graduates,
// Frost: mean number of days with temperatures below freezing,
// Area: state land area
    val spark = SparkSession.builder().appName("Spark Linear Regression").config("spark.some.config.option", "some-value").getOrCreate()
 
    // For implicit conversions like converting RDDs to DataFrames
    import spark.implicits._
 
    val dataList: List[(Double, Double, Double, Double, Double, Double, Double, Double)] = List(
      (3615, 3624, 2.1, 69.05, 15.1, 41.3, 20, 50708),
      (365, 6315, 1.5, 69.31, 11.3, 66.7, 152, 566432),
      (2212, 4530, 1.8, 70.55, 7.8, 58.1, 15, 113417),
      (2110, 3378, 1.9, 70.66, 10.1, 39.9, 65, 51945),
      (21198, 5114, 1.1, 71.71, 10.3, 62.6, 20, 156361),
      (2541, 4884, 0.7, 72.06, 6.8, 63.9, 166, 103766),
      (3100, 5348, 1.1, 72.48, 3.1, 56, 139, 4862),
      (579, 4809, 0.9, 70.06, 6.2, 54.6, 103, 1982),
      (8277, 4815, 1.3, 70.66, 10.7, 52.6, 11, 54090),
      (4931, 4091, 2, 68.54, 13.9, 40.6, 60, 58073),
      (868, 4963, 1.9, 73.6, 6.2, 61.9, 0, 6425),
      (813, 4119, 0.6, 71.87, 5.3, 59.5, 126, 82677),
      (11197, 5107, 0.9, 70.14, 10.3, 52.6, 127, 55748),
      (5313, 4458, 0.7, 70.88, 7.1, 52.9, 122, 36097),
      (2861, 4628, 0.5, 72.56, 2.3, 59, 140, 55941),
      (2280, 4669, 0.6, 72.58, 4.5, 59.9, 114, 81787),
      (3387, 3712, 1.6, 70.1, 10.6, 38.5, 95, 39650),
      (3806, 3545, 2.8, 68.76, 13.2, 42.2, 12, 44930),
      (1058, 3694, 0.7, 70.39, 2.7, 54.7, 161, 30920),
      (4122, 5299, 0.9, 70.22, 8.5, 52.3, 101, 9891),
      (5814, 4755, 1.1, 71.83, 3.3, 58.5, 103, 7826),
      (9111, 4751, 0.9, 70.63, 11.1, 52.8, 125, 56817),
      (3921, 4675, 0.6, 72.96, 2.3, 57.6, 160, 79289),
      (2341, 3098, 2.4, 68.09, 12.5, 41, 50, 47296),
      (4767, 4254, 0.8, 70.69, 9.3, 48.8, 108, 68995),
      (746, 4347, 0.6, 70.56, 5, 59.2, 155, 145587),
      (1544, 4508, 0.6, 72.6, 2.9, 59.3, 139, 76483),
      (590, 5149, 0.5, 69.03, 11.5, 65.2, 188, 109889),
      (812, 4281, 0.7, 71.23, 3.3, 57.6, 174, 9027),
      (7333, 5237, 1.1, 70.93, 5.2, 52.5, 115, 7521),
      (1144, 3601, 2.2, 70.32, 9.7, 55.2, 120, 121412),
      (18076, 4903, 1.4, 70.55, 10.9, 52.7, 82, 47831),
      (5441, 3875, 1.8, 69.21, 11.1, 38.5, 80, 48798),
      (637, 5087, 0.8, 72.78, 1.4, 50.3, 186, 69273),
      (10735, 4561, 0.8, 70.82, 7.4, 53.2, 124, 40975),
      (2715, 3983, 1.1, 71.42, 6.4, 51.6, 82, 68782),
      (2284, 4660, 0.6, 72.13, 4.2, 60, 44, 96184),
      (11860, 4449, 1, 70.43, 6.1, 50.2, 126, 44966),
      (931, 4558, 1.3, 71.9, 2.4, 46.4, 127, 1049),
      (2816, 3635, 2.3, 67.96, 11.6, 37.8, 65, 30225),
      (681, 4167, 0.5, 72.08, 1.7, 53.3, 172, 75955),
      (4173, 3821, 1.7, 70.11, 11, 41.8, 70, 41328),
      (12237, 4188, 2.2, 70.9, 12.2, 47.4, 35, 262134),
      (1203, 4022, 0.6, 72.9, 4.5, 67.3, 137, 82096),
      (472, 3907, 0.6, 71.64, 5.5, 57.1, 168, 9267),
      (4981, 4701, 1.4, 70.08, 9.5, 47.8, 85, 39780),
      (3559, 4864, 0.6, 71.72, 4.3, 63.5, 32, 66570),
      (1799, 3617, 1.4, 69.48, 6.7, 41.6, 100, 24070),
      (4589, 4468, 0.7, 72.48, 3, 54.5, 149, 54464),
      (376, 4566, 0.6, 70.29, 6.9, 62.9, 173, 97203))
 
    val data = dataList.toDF("Population", "Income", "Illiteracy", "LifeExp", "Murder", "HSGrad", "Frost", "Area")

Build the linear regression model

 

val colArray = Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")
 
val assembler = new VectorAssembler().setInputCols(colArray).setOutputCol("features")
 
val vecDF: DataFrame = assembler.transform(data)
 
// Build the model to predict the murder rate (Murder)
// Set the linear regression parameters
val lr1 = new LinearRegression()
val lr2 = lr1.setFeaturesCol("features").setLabelCol("Murder").setFitIntercept(true)
// RegParam: regularization parameter
val lr3 = lr2.setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
val lr = lr3
 
// Fit the model
val lrModel = lr.fit(vecDF)
 
// Print all model parameters
lrModel.extractParamMap()
// Print the coefficients and intercept for linear regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
 
val predictions = lrModel.transform(vecDF)
predictions.selectExpr("Murder", "round(prediction,1) as prediction").show
 
// Summarize the model over the training set and print out some metrics
val trainingSummary = lrModel.summary
println(s"numIterations: ${trainingSummary.totalIterations}")
println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
trainingSummary.residuals.show()
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"r2: ${trainingSummary.r2}")
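The summary metrics printed above can be reproduced from labels and predictions alone. Below is a minimal plain-Scala sketch; `rmse` and `r2` are illustrative helpers, not Spark API:

```scala
// RMSE = sqrt(mean((y - yhat)^2))
def rmse(labels: Seq[Double], preds: Seq[Double]): Double = {
  val sqErr = labels.zip(preds).map { case (y, p) => (y - p) * (y - p) }
  math.sqrt(sqErr.sum / sqErr.length)
}

// R^2 = 1 - SS_res / SS_tot, the fraction of label variance the model explains
def r2(labels: Seq[Double], preds: Seq[Double]): Double = {
  val mean = labels.sum / labels.length
  val ssRes = labels.zip(preds).map { case (y, p) => (y - p) * (y - p) }.sum
  val ssTot = labels.map(y => (y - mean) * (y - mean)).sum
  1.0 - ssRes / ssTot
}
```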

Code execution results

 

// Print all model parameters
lrModel.extractParamMap()
res15: org.apache.spark.ml.param.ParamMap =
{
    linReg_2ba28140e39a-elasticNetParam: 0.8,
    linReg_2ba28140e39a-featuresCol: features,
    linReg_2ba28140e39a-fitIntercept: true,
    linReg_2ba28140e39a-labelCol: Murder,
    linReg_2ba28140e39a-maxIter: 10,
    linReg_2ba28140e39a-predictionCol: prediction,
    linReg_2ba28140e39a-regParam: 0.3,
    linReg_2ba28140e39a-solver: auto,
    linReg_2ba28140e39a-standardization: true,
    linReg_2ba28140e39a-tol: 1.0E-6
}
 
// Print the coefficients and intercept for linear regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
Coefficients: [1.36662199778084E-4,0.0,1.1834384307116244,-1.4580829641757522,0.0,-0.010686434270049252,4.051355050528196E-6] Intercept: 109.589659881471
 
val predictions = lrModel.transform(vecDF)
predictions: org.apache.spark.sql.DataFrame = [Population: double, Income: double ... 8 more fields]
 
predictions.selectExpr("Murder", "round(prediction,1) as prediction").show
+------+----------+
|Murder|prediction|
+------+----------+
|  15.1|      11.9|
|  11.3|      11.0|
|   7.8|       9.5|
|  10.1|       8.6|
|  10.3|       9.6|
|   6.8|       4.3|
|   3.1|       4.2|
|   6.2|       7.5|
|  10.7|       9.3|
|  13.9|      12.3|
|   6.2|       4.7|
|   5.3|       4.6|
|  10.3|       8.8|
|   7.1|       6.6|
|   2.3|       3.5|
|   4.5|       3.9|
|  10.6|       8.9|
|  13.2|      13.2|
|   2.7|       6.3|
|   8.5|       7.8|
+------+----------+
only showing top 20 rows
 
// Summarize the model over the training set and print out some metrics
val trainingSummary = lrModel.summary
trainingSummary: org.apache.spark.ml.regression.LinearRegressionTrainingSummary = org.apache.spark.ml.regression.LinearRegressionTrainingSummary@68a83d76
 
println(s"numIterations: ${trainingSummary.totalIterations}")
numIterations: 11
 
println(s"objectiveHistory: ${trainingSummary.objectiveHistory.toList}")
objectiveHistory: List(0.49000000000000016, 0.3919242806809093, 0.19908078426904946, 0.1901453492751914, 0.17981874256031405, 0.17878173084286247, 0.1787617816935607, 0.17875431854661641, 0.17874702637141196, 0.17874512271568685, 0.1787449876896829)
trainingSummary.residuals.show()
+--------------------+
|           residuals|
+--------------------+
|  3.2200068116713023|
|  0.2745518816306607|
| -1.6535887417767414|
|   1.485762696757325|
|  0.6509766532389172|
|   2.457688146554534|
| -1.0675250558261182|
| -1.2879164685248439|
|  1.3672723619868314|
|  1.6125000289597242|
|   1.532060517905248|
|  0.6931301635074645|
|  1.5163001982000175|
| 0.46227066807431605|
| -1.2044058248740273|
|  0.6032541157521649|
|     1.7201545753635|
|-0.01942937427384...|
|  -3.632947522687547|
|  0.7077675962948007|
+--------------------+
only showing top 20 rows
 
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
RMSE: 1.6663615527314546
 
println(s"r2: ${trainingSummary.r2}")
r2: 0.7920794990832152

Model tuning with Train-Validation Split

 

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val colArray = Array("Population", "Income", "Illiteracy", "LifeExp", "HSGrad", "Frost", "Area")
 
val vecDF: DataFrame = new VectorAssembler().setInputCols(colArray).setOutputCol("features").transform(data)
 
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.9, 0.1), seed = 12345)
 
// Build the model to predict the murder rate (Murder); set the linear regression parameters
// (leave the estimator unfitted here: Train-Validation Split will fit it for each parameter combination)
val lr = new LinearRegression().setFeaturesCol("features").setLabelCol("Murder")
 
// Set up the pipeline
val pipeline = new Pipeline().setStages(Array(lr))
 
// Build the parameter grid
val paramGrid = new ParamGridBuilder().addGrid(lr.fitIntercept).addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)).addGrid(lr.maxIter, Array(10, 100)).build()
 
// Select (prediction, true label) and compute the test error.
// Note RegEvaluator.isLargerBetter: whether a larger or smaller metric value is better is recognized automatically
val RegEvaluator = new RegressionEvaluator().setLabelCol(lr.getLabelCol).setPredictionCol(lr.getPredictionCol).setMetricName("rmse")
 
val trainValidationSplit = new TrainValidationSplit().setEstimator(pipeline).setEvaluator(RegEvaluator).setEstimatorParamMaps(paramGrid).setTrainRatio(0.8) // train/validation split ratio
 
// Run train validation split, and choose the best set of parameters.
val tvModel = trainValidationSplit.fit(trainingDF)
 
// Inspect all model parameters
tvModel.extractParamMap()

tvModel.getEstimatorParamMaps.length
tvModel.getEstimatorParamMaps.foreach { println } // the set of parameter combinations

tvModel.getEvaluator.extractParamMap() // evaluator parameters

tvModel.getEvaluator.isLargerBetter // whether a larger or smaller metric value is better

tvModel.getTrainRatio

// Predict with the best parameter combination
tvModel.transform(testDF).select("features", "Murder", "prediction").show()
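The grid built above crosses fitIntercept (2 values), elasticNetParam (3 values), and maxIter (2 values), so Train-Validation Split trains and evaluates 2 * 3 * 2 = 12 candidate models on the 80/20 split. The enumeration it performs can be sketched in plain Scala:

```scala
// Cross product of the three hyperparameter lists: 2 * 3 * 2 = 12 combinations,
// matching tvModel.getEstimatorParamMaps.length in the output below
val grid = for {
  fitIntercept <- Seq(true, false)
  alpha        <- Seq(0.0, 0.5, 1.0)
  maxIter      <- Seq(10, 100)
} yield (fitIntercept, alpha, maxIter)

println(grid.length) // 12
```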

Tuning code execution results

 

// Inspect all model parameters
tvModel.extractParamMap()
res45: org.apache.spark.ml.param.ParamMap =
{
    tvs_5de7d3dd1977-estimator: pipeline_062a1dffe557,
    tvs_5de7d3dd1977-estimatorParamMaps: [Lorg.apache.spark.ml.param.ParamMap;@60298de1,
    tvs_5de7d3dd1977-evaluator: regEval_05204824acb9,
    tvs_5de7d3dd1977-seed: -1772833110,
    tvs_5de7d3dd1977-trainRatio: 0.8
}
 
tvModel.getEstimatorParamMaps.length
res46: Int = 12
 
tvModel.getEstimatorParamMaps.foreach { println } // the set of parameter combinations
{
    linReg_75628a5554b4-elasticNetParam: 0.0,
    linReg_75628a5554b4-fitIntercept: true,
    linReg_75628a5554b4-maxIter: 10
}
{
    linReg_75628a5554b4-elasticNetParam: 0.0,
    linReg_75628a5554b4-fitIntercept: true,
    linReg_75628a5554b4-maxIter: 100
}
{
    linReg_75628a5554b4-elasticNetParam: 0.0,
    linReg_75628a5554b4-fitIntercept: false,
    linReg_75628a5554b4-maxIter: 10
}
{
    linReg_75628a5554b4-elasticNetParam: 0.0,
    linReg_75628a5554b4-fitIntercept: false,
    linReg_75628a5554b4-maxIter: 100
}
{
    linReg_75628a5554b4-elasticNetParam: 0.5,
    linReg_75628a5554b4-fitIntercept: true,
    linReg_75628a5554b4-maxIter: 10
}
{
    linReg_75628a5554b4-elasticNetParam: 0.5,
    linReg_75628a5554b4-fitIntercept: true,
    linReg_75628a5554b4-maxIter: 100
}
{
    linReg_75628a5554b4-elasticNetParam: 0.5,
    linReg_75628a5554b4-fitIntercept: false,
    linReg_75628a5554b4-maxIter: 10
}
{
    linReg_75628a5554b4-elasticNetParam: 0.5,
    linReg_75628a5554b4-fitIntercept: false,
    linReg_75628a5554b4-maxIter: 100
}
{
    linReg_75628a5554b4-elasticNetParam: 1.0,
    linReg_75628a5554b4-fitIntercept: true,
    linReg_75628a5554b4-maxIter: 10
}
{
    linReg_75628a5554b4-elasticNetParam: 1.0,
    linReg_75628a5554b4-fitIntercept: true,
    linReg_75628a5554b4-maxIter: 100
}
{
    linReg_75628a5554b4-elasticNetParam: 1.0,
    linReg_75628a5554b4-fitIntercept: false,
    linReg_75628a5554b4-maxIter: 10
}
{
    linReg_75628a5554b4-elasticNetParam: 1.0,
    linReg_75628a5554b4-fitIntercept: false,
    linReg_75628a5554b4-maxIter: 100
}
 
tvModel.getEvaluator.extractParamMap() // evaluator parameters
res48: org.apache.spark.ml.param.ParamMap =
{
    regEval_05204824acb9-labelCol: Murder,
    regEval_05204824acb9-metricName: rmse,
    regEval_05204824acb9-predictionCol: prediction
}
 
tvModel.getEvaluator.isLargerBetter // whether a larger or smaller metric value is better
res49: Boolean = false
 
tvModel.getTrainRatio
res50: Double = 0.8
 
tvModel.transform(testDF).select("features", "Murder", "prediction").show()
+--------------------+------+------------------+
|            features|Murder|        prediction|
+--------------------+------+------------------+
|[1058.0,3694.0,0....|   2.7| 6.917232043935343|
|[2341.0,3098.0,2....|  12.5|14.760329005533478|
|[472.0,3907.0,0.6...|   5.5| 4.182074651181182|
|[812.0,4281.0,0.7...|   3.3| 4.915905572667441|
|[2816.0,3635.0,2....|  11.6|14.219231061596304|
|[4589.0,4468.0,0....|   3.0| 3.483554528704758|
+--------------------+------+------------------+
