Structure of the raw data (tableData)
root
|-- user_id: integer (nullable = false)
|-- city: string (nullable = true)
|-- category: integer (nullable = false)
|-- from_place: string (nullable = true)
|-- to_place: string (nullable = true)
|-- isclick: integer (nullable = false)
|-- future_day: integer (nullable = false)
|-- banner_min_time: integer (nullable = false)
|-- banner_min_price: double (nullable = false)
|-- start_city_id: integer (nullable = false)
|-- start_city_name: string (nullable = true)
|-- end_city_id: integer (nullable = false)
|-- end_city_name: string (nullable = true)
|-- page_train: integer (nullable = false)
|-- page_flight: integer (nullable = false)
|-- page_bus: integer (nullable = false)
|-- page_ship: integer (nullable = false)
|-- page_transfer: integer (nullable = false)
|-- start_end_distance: double (nullable = false)
|-- total_transport: integer (nullable = false)
|-- high_railway_percent: double (nullable = false)
|-- avg_time: integer (nullable = false)
|-- min_time: integer (nullable = false)
|-- avg_price: double (nullable = false)
|-- min_price: double (nullable = false)
|-- label_05060801: integer (nullable = false)
|-- label_05060701: integer (nullable = false)
|-- label_05060601: integer (nullable = false)
|-- label_02050601: integer (nullable = false)
|-- label_02050501: integer (nullable = false)
|-- label_02050401: integer (nullable = false)
|-- is_match_category: integer (nullable = false)
|-- train_consumer_prefer: integer (nullable = false)
|-- flight_consumer_prefer: integer (nullable = false)
|-- bus_consumer_prefer: integer (nullable = false)
|-- create_date: timestamp (nullable = true)
+---------+----+--------+----------+--------+-------+----------+---------------+----------------+-------------+---------------+-----------+-------------+----------+-----------+--------+---------+-------------+------------------+---------------+--------------------+--------+--------+---------+---------+--------------+--------------+--------------+--------------+--------------+--------------+-----------------+---------------------+----------------------+-------------------+---------------------+
|user_id |city|category|from_place|to_place|isclick|future_day|banner_min_time|banner_min_price|start_city_id|start_city_name|end_city_id|end_city_name|page_train|page_flight|page_bus|page_ship|page_transfer|start_end_distance|total_transport|high_railway_percent|avg_time|min_time|avg_price|min_price|label_05060801|label_05060701|label_05060601|label_02050601|label_02050501|label_02050401|is_match_category|train_consumer_prefer|flight_consumer_prefer|bus_consumer_prefer|create_date |
+---------+----+--------+----------+--------+-------+----------+---------------+----------------+-------------+---------------+-----------+-------------+----------+-----------+--------+---------+-------------+------------------+---------------+--------------------+--------+--------+---------+---------+--------------+--------------+--------------+--------------+--------------+--------------+-----------------+---------------------+----------------------+-------------------+---------------------+
|100523214|桂林市 |2 |桂林 |南昌 |0 |1 |27900 |1850.0 |2099 |桂林市 |1235 |南昌市 |6 |0 |0 |0 |4 |681000.0 |16 |0.45 |34507 |17760 |220.0 |112.0 |0 |0 |0 |0 |0 |1 |0 |0 |0 |0 |2019-02-20 00:00:00.0|
|100523214|桂林市 |2 |桂林 |株洲 |0 |1 |51900 |1410.0 |2099 |桂林市 |1805 |株洲市 |6 |0 |0 |0 |4 |412000.0 |21 |0.47 |19262 |10680 |117.0 |69.0 |0 |0 |0 |0 |0 |1 |0 |0 |0 |0 |2019-02-20 00:00:00.0|
|102338191|广州市 |2 |广州南 |滕州东 |0 |1 |7500 |1424.0 |1932 |广州市 |1384 |滕州市 |3 |0 |0 |0 |2 |1386000.0 |2 |0.0 |81600 |81600 |229.0 |229.0 |0 |0 |0 |0 |1 |1 |0 |0 |0 |0 |2019-02-20 00:00:00.0|
|102338191|广州市 |2 |广州南 |滕州东 |0 |13 |7500 |781.0 |1932 |广州市 |1384 |滕州市 |3 |0 |0 |0 |2 |1386000.0 |2 |0.0 |81600 |81600 |229.0 |229.0 |0 |0 |0 |0 |1 |1 |0 |0 |0 |0 |2019-02-20 00:00:00.0|
|103053718|朝阳市 |3 |赤峰 |通辽 |0 |1 |14400 |80.0 |371 |赤峰市 |384 |通辽市 |11 |2 |0 |0 |9 |314000.0 |12 |0.0 |21139 |13440 |41.0 |23.5 |0 |0 |1 |0 |0 |1 |0 |2 |0 |0 |2019-02-20 00:00:00.0|
|103053718|朝阳市 |3 |赤峰 |霍林郭勒 |0 |1 |14400 |162.0 |371 |赤峰市 |392 |霍林郭勒市 |11 |2 |0 |0 |9 |370000.0 |-99 |-99.0 |-99 |-99 |-99.0 |-99.0 |0 |0 |1 |0 |0 |1 |0 |2 |0 |0 |2019-02-20 00:00:00.0|
|103053718|朝阳市 |3 |杨杖子 |叶柏寿 |0 |1 |9000 |6.5 |569 |凌源市 |566 |建平县 |6 |0 |0 |0 |4 |26000.0 |11 |0.0 |4669 |2340 |7.0 |3.0 |0 |0 |1 |0 |0 |1 |0 |2 |0 |0 |2019-02-20 00:00:00.0|
|103425337|null|2 |大连 |长春 |0 |0 |6300 |380.0 |477 |大连市 |578 |长春市 |3 |0 |0 |0 |1 |627000.0 |44 |0.57 |20231 |9600 |211.0 |75.0 |0 |1 |0 |0 |0 |1 |0 |3 |0 |0 |2019-02-20 00:00:00.0|
|105705123|邢台市 |7 |清河城 |北京 |0 |1 |60 |32.0 |117 |清河县 |1 |北京市 |22 |0 |0 |0 |8 |325000.0 |14 |0.0 |17432 |14520 |54.0 |50.5 |0 |0 |0 |0 |1 |1 |0 |2 |1 |0 |2019-02-20 00:00:00.0|
|110193722|阜新市 |2 |阜新 |北京 |0 |0 |4500 |2210.0 |533 |阜新市 |1 |北京市 |2 |0 |0 |0 |1 |501000.0 |3 |0.0 |46220 |35820 |79.0 |72.0 |0 |1 |0 |1 |1 |1 |0 |4 |1 |1 |2019-02-20 00:00:00.0|
|110193722|阜新市 |3 |阜新 |北京 |0 |1 |25200 |180.0 |533 |阜新市 |1 |北京市 |2 |0 |0 |0 |1 |501000.0 |3 |0.0 |46220 |35820 |79.0 |72.0 |0 |1 |0 |1 |1 |1 |0 |4 |1 |1 |2019-02-20 00:00:00.0|
|110193722|阜新市 |3 |阜新 |北京 |0 |2 |25200 |180.0 |533 |阜新市 |1 |北京市 |2 |0 |0 |0 |1 |501000.0 |3 |0.0 |46220 |35820 |79.0 |72.0 |0 |1 |0 |1 |1 |1 |0 |4 |1 |1 |2019-02-20 00:00:00.0|
|110234656|北京市 |3 |保定 |石家庄 |0 |1 |7200 |38.0 |121 |保定市 |36 |石家庄市 |23 |0 |0 |0 |8 |124000.0 |195 |0.46 |4423 |2160 |44.0 |10.5 |0 |1 |0 |1 |1 |1 |0 |4 |2 |2 |2019-02-20 00:00:00.0|
|110234656|北京市 |3 |保定 |天津 |0 |1 |9000 |61.0 |121 |保定市 |18 |天津市 |23 |0 |0 |0 |8 |152000.0 |64 |0.66 |7423 |3180 |55.0 |23.5 |0 |1 |0 |1 |1 |1 |0 |4 |2 |2 |2019-02-20 00:00:00.0|
|110234656|北京市 |3 |石家庄 |青岛 |0 |1 |27000 |195.0 |36 |石家庄市 |1358 |青岛市 |23 |0 |0 |0 |8 |565000.0 |22 |0.63 |11400 |11400 |205.0 |81.0 |0 |1 |0 |1 |1 |1 |0 |4 |2 |2 |2019-02-20 00:00:00.0|
|110234656|北京市 |2 |天津 |青岛 |0 |1 |4200 |383.0 |18 |天津市 |1358 |青岛市 |23 |0 |0 |0 |8 |437000.0 |32 |0.47 |4581 |4200 |199.0 |91.0 |0 |1 |0 |1 |1 |1 |0 |4 |2 |2 |2019-02-20 00:00:00.0|
|110234656|北京市 |3 |保定 |青岛 |0 |1 |30600 |196.0 |121 |保定市 |1358 |青岛市 |23 |0 |0 |0 |8 |535000.0 |-99 |-99.0 |-99 |-99 |-99.0 |-99.0 |0 |1 |0 |1 |1 |1 |0 |4 |2 |2 |2019-02-20 00:00:00.0|
|110274557|邵阳市 |3 |衡阳 |隆回 |0 |1 |12600 |85.0 |1821 |衡阳市 |1841 |隆回县 |2 |0 |0 |0 |1 |154000.0 |12 |1.0 |4020 |3900 |60.0 |54.0 |0 |0 |0 |0 |0 |0 |0 |0 |0 |0 |2019-02-20 00:00:00.0|
|115392135|株洲市 |3 |长沙 |大丰 |0 |0 |3600 |20.0 |1795 |长沙市 |1805 |株洲市 |4 |0 |0 |0 |1 |48000.0 |156 |0.07 |2752 |420 |20.0 |4.0 |0 |0 |0 |0 |1 |1 |0 |2 |0 |0 |2019-02-20 00:00:00.0|
|115392135|株洲市 |26 |大丰 |长沙 |0 |0 |7200 |75.0 |1805 |株洲市 |1795 |长沙市 |11 |0 |0 |0 |1 |48000.0 |161 |0.09 |2652 |480 |19.0 |5.0 |0 |0 |0 |0 |1 |1 |0 |2 |0 |0 |2019-02-20 00:00:00.0|
+---------+----+--------+----------+--------+-------+----------+---------------+----------------+-------------+---------------+-----------+-------------+----------+-----------+--------+---------+-------------+------------------+---------------+--------------------+--------+--------+---------+---------+--------------+--------------+--------------+--------------+--------------+--------------+-----------------+---------------------+----------------------+-------------------+---------------------+
only showing top 20 rows
isclick is the label to predict; the other fields are the candidate features.
Building the features
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
...
// Assemble the numeric columns into a single "features" vector column.
VectorAssembler assembler = new VectorAssembler().setInputCols(new String[]{
        "category", "future_day", "banner_min_time", "banner_min_price",
        "page_train", "page_flight", "page_bus", "page_transfer",
        "start_end_distance", "total_transport", "high_railway_percent", "avg_time", "min_time",
        "avg_price", "min_price",
        "label_05060801", "label_05060701", "label_05060601", "label_02050601", "label_02050501", "label_02050401",
        "is_match_category", "train_consumer_prefer", "flight_consumer_prefer", "bus_consumer_prefer"
}).setOutputCol("features");
Dataset<Row> trainData = assembler.transform(tableData);

// Base estimator. maxIter, regParam and elasticNetParam are overridden by the grid below.
LogisticRegression lr = new LogisticRegression()
        .setElasticNetParam(0)
        .setLabelCol("isclick")
        .setFeaturesCol("features")
        .setMaxIter(60)
        .setRegParam(0.1);

// 2 x 2 x 2 = 8 parameter combinations to search.
ParamMap[] paramGrid = new ParamGridBuilder()
        .addGrid(lr.maxIter(), new int[]{50, 80})
        .addGrid(lr.regParam(), new double[]{0.005, 0.01})
        .addGrid(lr.elasticNetParam(), new double[]{0, 0.1})
        .build();

// The evaluator's default metric is areaUnderROC.
BinaryClassificationEvaluator binaryClassificationEvaluator =
        new BinaryClassificationEvaluator().setLabelCol("isclick");

CrossValidator crossValidator = new CrossValidator()
        .setEstimator(lr)
        .setEvaluator(binaryClassificationEvaluator)
        .setEstimatorParamMaps(paramGrid)
        .setNumFolds(3);
CrossValidatorModel cvModel = crossValidator.fit(trainData);
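Here I use logistic regression (LR) with the number of folds set to 3 and AUC (area under the ROC curve, the default metric of BinaryClassificationEvaluator) as the evaluation metric. The resulting CrossValidatorModel holds the best model found by cross-validation.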
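To make the cost of this search concrete, here is a minimal plain-Java sketch (no Spark dependency) that enumerates the grid defined above and counts the model fits it implies. The class name GridSizeSketch and its helper methods are hypothetical, and the enumeration order is only illustrative; it is not meant to reproduce the order in which Spark builds its ParamMap array.

```java
import java.util.ArrayList;
import java.util.List;

final class GridSizeSketch {
    // Candidate values copied from the ParamGridBuilder call above.
    static final int[] MAX_ITER = {50, 80};
    static final double[] REG_PARAM = {0.005, 0.01};
    static final double[] ELASTIC_NET = {0.0, 0.1};
    static final int NUM_FOLDS = 3;

    // Enumerates every (maxIter, regParam, elasticNetParam) combination.
    // The order is illustrative and may differ from Spark's own ordering.
    static List<double[]> combinations() {
        List<double[]> combos = new ArrayList<>();
        for (int maxIter : MAX_ITER) {
            for (double reg : REG_PARAM) {
                for (double en : ELASTIC_NET) {
                    combos.add(new double[]{maxIter, reg, en});
                }
            }
        }
        return combos;
    }

    // During the search, each combination is fitted once per fold.
    static int searchFits() {
        return combinations().size() * NUM_FOLDS;
    }

    public static void main(String[] args) {
        System.out.println("combinations: " + combinations().size());
        System.out.println("fits during search: " + searchFits());
    }
}
```

On top of the fits counted here, CrossValidator refits the best combination once on the full dataset, so adding values to any grid dimension multiplies the training time accordingly.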
Inspecting the cross-validation result for each parameter combination
System.out.println("------------------------");
ParamMap[] paramMaps = cvModel.getEstimatorParamMaps();
// avgMetrics()[i] is the metric for paramMaps[i], averaged over the folds.
double[] rocArea = cvModel.avgMetrics();
for (int i = 0; i < paramMaps.length; i++) {
    System.out.println("------------" + i + "-------------");
    System.out.println("param:" + paramMaps[i]);
    System.out.println("rocArea:" + rocArea[i]);
}
Output:
------------0-------------
param:{
logreg_54489c6b4a15-elasticNetParam: 0.0,
logreg_54489c6b4a15-maxIter: 50,
logreg_54489c6b4a15-regParam: 0.005
}
rocArea:0.7561587339955449
------------1-------------
param:{
logreg_54489c6b4a15-elasticNetParam: 0.1,
logreg_54489c6b4a15-maxIter: 50,
logreg_54489c6b4a15-regParam: 0.005
}
rocArea:0.7601808941246924
------------2-------------
param:{
logreg_54489c6b4a15-elasticNetParam: 0.0,
logreg_54489c6b4a15-maxIter: 50,
logreg_54489c6b4a15-regParam: 0.01
}
rocArea:0.7458205940087932
------------3-------------
param:{
logreg_54489c6b4a15-elasticNetParam: 0.1,
logreg_54489c6b4a15-maxIter: 50,
logreg_54489c6b4a15-regParam: 0.01
}
rocArea:0.7519450829980202
------------4-------------
param:{
logreg_54489c6b4a15-elasticNetParam: 0.0,
logreg_54489c6b4a15-maxIter: 80,
logreg_54489c6b4a15-regParam: 0.005
}
rocArea:0.7561587339955449
------------5-------------
param:{
logreg_54489c6b4a15-elasticNetParam: 0.1,
logreg_54489c6b4a15-maxIter: 80,
logreg_54489c6b4a15-regParam: 0.005
}
rocArea:0.7600964897398843
------------6-------------
param:{
logreg_54489c6b4a15-elasticNetParam: 0.0,
logreg_54489c6b4a15-maxIter: 80,
logreg_54489c6b4a15-regParam: 0.01
}
rocArea:0.7458205940087932
------------7-------------
param:{
logreg_54489c6b4a15-elasticNetParam: 0.1,
logreg_54489c6b4a15-maxIter: 80,
logreg_54489c6b4a15-regParam: 0.01
}
rocArea:0.7519466667404875
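Since a larger AUC is better, the winning combination is simply the index with the largest averaged metric. The following plain-Java sketch applies that rule to the eight rocArea values printed above; the class name BestParamSketch and the argMax helper are hypothetical and only illustrate the selection rule, they are not Spark's own code.

```java
final class BestParamSketch {
    // avgMetrics values copied from the printed output above (indices 0 to 7).
    static final double[] AVG_METRICS = {
        0.7561587339955449, 0.7601808941246924, 0.7458205940087932, 0.7519450829980202,
        0.7561587339955449, 0.7600964897398843, 0.7458205940087932, 0.7519466667404875
    };

    // Returns the index of the largest metric; ties resolve to the first occurrence.
    static int argMax(double[] metrics) {
        int best = 0;
        for (int i = 1; i < metrics.length; i++) {
            if (metrics[i] > metrics[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int best = argMax(AVG_METRICS);
        System.out.println("best index: " + best + ", rocArea: " + AVG_METRICS[best]);
    }
}
```

For these values the largest metric sits at index 1 (elasticNetParam 0.1, maxIter 50, regParam 0.005). Note that raising maxIter from 50 to 80 leaves the scores almost unchanged, which suggests the optimizer has already converged within 50 iterations.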
Inspecting the parameters of the best model
LogisticRegressionModel logisticRegressionModel = (LogisticRegressionModel) (cvModel.bestModel());
System.out.println("---------coefficients is :------------------------------");
System.out.println(logisticRegressionModel.coefficients());
System.out.println("---------intercept is :------------------------------");
System.out.println(logisticRegressionModel.intercept());
System.out.println("------------best param is :------------------------------");
System.out.println("ElasticNetParam:" + logisticRegressionModel.getElasticNetParam());
System.out.println("RegParam:" + logisticRegressionModel.getRegParam());
System.out.println("MaxIter:" + logisticRegressionModel.getMaxIter());
Output:
---------coefficients is :------------------------------
[-0.020093624027311234,-0.023840313521126238,0.0,-4.870154374284758E-4,0.0,0.065071797632406,0.07375116647372162,0.0,2.1894539394004693E-7,-0.0018188669432944152,-0.0024137353437925675,-7.326062556226837E-6,-6.036644002424079E-7,4.186151682563356E-4,0.0,-0.4216808408256718,-0.25707113499256773,0.06709830247479415,0.0,0.009612099202465701,0.0,0.0,-0.08075082386845575,0.025511061645731456,-0.15643820797235122]
---------intercept is :------------------------------
-3.092110054323254
------------best param is :------------------------------
ElasticNetParam:0.1
RegParam:0.005
MaxIter:50
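To show what these numbers mean, the sketch below scores one row by hand in plain Java: the click probability is the sigmoid of the intercept plus the dot product of the coefficients and the feature values. The coefficients and intercept are copied from the output above and are assumed to follow the VectorAssembler column order; the class name ClickProbabilitySketch, its methods, and the choice of the first sample row as input are hypothetical illustrations, not part of the original program.

```java
final class ClickProbabilitySketch {
    static final double INTERCEPT = -3.092110054323254;

    // Coefficients copied from the output above, in VectorAssembler column order.
    static final double[] COEFFICIENTS = {
        -0.020093624027311234, -0.023840313521126238, 0.0, -4.870154374284758E-4,
        0.0, 0.065071797632406, 0.07375116647372162, 0.0,
        2.1894539394004693E-7, -0.0018188669432944152, -0.0024137353437925675,
        -7.326062556226837E-6, -6.036644002424079E-7,
        4.186151682563356E-4, 0.0,
        -0.4216808408256718, -0.25707113499256773, 0.06709830247479415,
        0.0, 0.009612099202465701, 0.0,
        0.0, -0.08075082386845575, 0.025511061645731456, -0.15643820797235122
    };

    // Feature values of the first sample row (user 100523214, 桂林 to 南昌),
    // with page_ship skipped because it is not an assembler input.
    static final double[] SAMPLE_ROW = {
        2, 1, 27900, 1850.0,
        6, 0, 0, 4,
        681000.0, 16, 0.45, 34507, 17760,
        220.0, 112.0,
        0, 0, 0, 0, 0, 1,
        0, 0, 0, 0
    };

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // P(isclick = 1) = sigmoid(intercept + coefficients . features)
    static double clickProbability(double[] features) {
        if (features.length != COEFFICIENTS.length) {
            throw new IllegalArgumentException("expected " + COEFFICIENTS.length + " features");
        }
        double margin = INTERCEPT;
        for (int i = 0; i < features.length; i++) {
            margin += COEFFICIENTS[i] * features[i];
        }
        return sigmoid(margin);
    }

    public static void main(String[] args) {
        System.out.println("P(click) for sample row: " + clickProbability(SAMPLE_ROW));
    }
}
```

Several coefficients are exactly 0.0. That is consistent with the L1 component introduced by ElasticNetParam 0.1, which can drive weak features out of the model entirely.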
With that, cross-validation for Spark logistic regression is working end to end.