Spark ML cross-validation demo

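This post walks through a Spark-based data processing flow that uses logistic regression for predictive modeling and selects the best parameters through cross-validation. It covers feature assembly, model training, and parameter grid search.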

Structure of the raw data, tableData

root
 |-- user_id: integer (nullable = false)
 |-- city: string (nullable = true)
 |-- category: integer (nullable = false)
 |-- from_place: string (nullable = true)
 |-- to_place: string (nullable = true)
 |-- isclick: integer (nullable = false)
 |-- future_day: integer (nullable = false)
 |-- banner_min_time: integer (nullable = false)
 |-- banner_min_price: double (nullable = false)
 |-- start_city_id: integer (nullable = false)
 |-- start_city_name: string (nullable = true)
 |-- end_city_id: integer (nullable = false)
 |-- end_city_name: string (nullable = true)
 |-- page_train: integer (nullable = false)
 |-- page_flight: integer (nullable = false)
 |-- page_bus: integer (nullable = false)
 |-- page_ship: integer (nullable = false)
 |-- page_transfer: integer (nullable = false)
 |-- start_end_distance: double (nullable = false)
 |-- total_transport: integer (nullable = false)
 |-- high_railway_percent: double (nullable = false)
 |-- avg_time: integer (nullable = false)
 |-- min_time: integer (nullable = false)
 |-- avg_price: double (nullable = false)
 |-- min_price: double (nullable = false)
 |-- label_05060801: integer (nullable = false)
 |-- label_05060701: integer (nullable = false)
 |-- label_05060601: integer (nullable = false)
 |-- label_02050601: integer (nullable = false)
 |-- label_02050501: integer (nullable = false)
 |-- label_02050401: integer (nullable = false)
 |-- is_match_category: integer (nullable = false)
 |-- train_consumer_prefer: integer (nullable = false)
 |-- flight_consumer_prefer: integer (nullable = false)
 |-- bus_consumer_prefer: integer (nullable = false)
 |-- create_date: timestamp (nullable = true)

+---------+----+--------+----------+--------+-------+----------+---------------+----------------+-------------+---------------+-----------+-------------+----------+-----------+--------+---------+-------------+------------------+---------------+--------------------+--------+--------+---------+---------+--------------+--------------+--------------+--------------+--------------+--------------+-----------------+---------------------+----------------------+-------------------+---------------------+
|user_id  |city|category|from_place|to_place|isclick|future_day|banner_min_time|banner_min_price|start_city_id|start_city_name|end_city_id|end_city_name|page_train|page_flight|page_bus|page_ship|page_transfer|start_end_distance|total_transport|high_railway_percent|avg_time|min_time|avg_price|min_price|label_05060801|label_05060701|label_05060601|label_02050601|label_02050501|label_02050401|is_match_category|train_consumer_prefer|flight_consumer_prefer|bus_consumer_prefer|create_date          |
+---------+----+--------+----------+--------+-------+----------+---------------+----------------+-------------+---------------+-----------+-------------+----------+-----------+--------+---------+-------------+------------------+---------------+--------------------+--------+--------+---------+---------+--------------+--------------+--------------+--------------+--------------+--------------+-----------------+---------------------+----------------------+-------------------+---------------------+
|100523214|桂林市 |2       |桂林        |南昌      |0      |1         |27900          |1850.0          |2099         |桂林市            |1235       |南昌市          |6         |0          |0       |0        |4            |681000.0          |16             |0.45                |34507   |17760   |220.0    |112.0    |0             |0             |0             |0             |0             |1             |0                |0                    |0                     |0                  |2019-02-20 00:00:00.0|
|100523214|桂林市 |2       |桂林        |株洲      |0      |1         |51900          |1410.0          |2099         |桂林市            |1805       |株洲市          |6         |0          |0       |0        |4            |412000.0          |21             |0.47                |19262   |10680   |117.0    |69.0     |0             |0             |0             |0             |0             |1             |0                |0                    |0                     |0                  |2019-02-20 00:00:00.0|
|102338191|广州市 |2       |广州南       |滕州东     |0      |1         |7500           |1424.0          |1932         |广州市            |1384       |滕州市          |3         |0          |0       |0        |2            |1386000.0         |2              |0.0                 |81600   |81600   |229.0    |229.0    |0             |0             |0             |0             |1             |1             |0                |0                    |0                     |0                  |2019-02-20 00:00:00.0|
|102338191|广州市 |2       |广州南       |滕州东     |0      |13        |7500           |781.0           |1932         |广州市            |1384       |滕州市          |3         |0          |0       |0        |2            |1386000.0         |2              |0.0                 |81600   |81600   |229.0    |229.0    |0             |0             |0             |0             |1             |1             |0                |0                    |0                     |0                  |2019-02-20 00:00:00.0|
|103053718|朝阳市 |3       |赤峰        |通辽      |0      |1         |14400          |80.0            |371          |赤峰市            |384        |通辽市          |11        |2          |0       |0        |9            |314000.0          |12             |0.0                 |21139   |13440   |41.0     |23.5     |0             |0             |1             |0             |0             |1             |0                |2                    |0                     |0                  |2019-02-20 00:00:00.0|
|103053718|朝阳市 |3       |赤峰        |霍林郭勒    |0      |1         |14400          |162.0           |371          |赤峰市            |392        |霍林郭勒市        |11        |2          |0       |0        |9            |370000.0          |-99            |-99.0               |-99     |-99     |-99.0    |-99.0    |0             |0             |1             |0             |0             |1             |0                |2                    |0                     |0                  |2019-02-20 00:00:00.0|
|103053718|朝阳市 |3       |杨杖子       |叶柏寿     |0      |1         |9000           |6.5             |569          |凌源市            |566        |建平县          |6         |0          |0       |0        |4            |26000.0           |11             |0.0                 |4669    |2340    |7.0      |3.0      |0             |0             |1             |0             |0             |1             |0                |2                    |0                     |0                  |2019-02-20 00:00:00.0|
|103425337|null|2       |大连        |长春      |0      |0         |6300           |380.0           |477          |大连市            |578        |长春市          |3         |0          |0       |0        |1            |627000.0          |44             |0.57                |20231   |9600    |211.0    |75.0     |0             |1             |0             |0             |0             |1             |0                |3                    |0                     |0                  |2019-02-20 00:00:00.0|
|105705123|邢台市 |7       |清河城       |北京      |0      |1         |60             |32.0            |117          |清河县            |1          |北京市          |22        |0          |0       |0        |8            |325000.0          |14             |0.0                 |17432   |14520   |54.0     |50.5     |0             |0             |0             |0             |1             |1             |0                |2                    |1                     |0                  |2019-02-20 00:00:00.0|
|110193722|阜新市 |2       |阜新        |北京      |0      |0         |4500           |2210.0          |533          |阜新市            |1          |北京市          |2         |0          |0       |0        |1            |501000.0          |3              |0.0                 |46220   |35820   |79.0     |72.0     |0             |1             |0             |1             |1             |1             |0                |4                    |1                     |1                  |2019-02-20 00:00:00.0|
|110193722|阜新市 |3       |阜新        |北京      |0      |1         |25200          |180.0           |533          |阜新市            |1          |北京市          |2         |0          |0       |0        |1            |501000.0          |3              |0.0                 |46220   |35820   |79.0     |72.0     |0             |1             |0             |1             |1             |1             |0                |4                    |1                     |1                  |2019-02-20 00:00:00.0|
|110193722|阜新市 |3       |阜新        |北京      |0      |2         |25200          |180.0           |533          |阜新市            |1          |北京市          |2         |0          |0       |0        |1            |501000.0          |3              |0.0                 |46220   |35820   |79.0     |72.0     |0             |1             |0             |1             |1             |1             |0                |4                    |1                     |1                  |2019-02-20 00:00:00.0|
|110234656|北京市 |3       |保定        |石家庄     |0      |1         |7200           |38.0            |121          |保定市            |36         |石家庄市         |23        |0          |0       |0        |8            |124000.0          |195            |0.46                |4423    |2160    |44.0     |10.5     |0             |1             |0             |1             |1             |1             |0                |4                    |2                     |2                  |2019-02-20 00:00:00.0|
|110234656|北京市 |3       |保定        |天津      |0      |1         |9000           |61.0            |121          |保定市            |18         |天津市          |23        |0          |0       |0        |8            |152000.0          |64             |0.66                |7423    |3180    |55.0     |23.5     |0             |1             |0             |1             |1             |1             |0                |4                    |2                     |2                  |2019-02-20 00:00:00.0|
|110234656|北京市 |3       |石家庄       |青岛      |0      |1         |27000          |195.0           |36           |石家庄市           |1358       |青岛市          |23        |0          |0       |0        |8            |565000.0          |22             |0.63                |11400   |11400   |205.0    |81.0     |0             |1             |0             |1             |1             |1             |0                |4                    |2                     |2                  |2019-02-20 00:00:00.0|
|110234656|北京市 |2       |天津        |青岛      |0      |1         |4200           |383.0           |18           |天津市            |1358       |青岛市          |23        |0          |0       |0        |8            |437000.0          |32             |0.47                |4581    |4200    |199.0    |91.0     |0             |1             |0             |1             |1             |1             |0                |4                    |2                     |2                  |2019-02-20 00:00:00.0|
|110234656|北京市 |3       |保定        |青岛      |0      |1         |30600          |196.0           |121          |保定市            |1358       |青岛市          |23        |0          |0       |0        |8            |535000.0          |-99            |-99.0               |-99     |-99     |-99.0    |-99.0    |0             |1             |0             |1             |1             |1             |0                |4                    |2                     |2                  |2019-02-20 00:00:00.0|
|110274557|邵阳市 |3       |衡阳        |隆回      |0      |1         |12600          |85.0            |1821         |衡阳市            |1841       |隆回县          |2         |0          |0       |0        |1            |154000.0          |12             |1.0                 |4020    |3900    |60.0     |54.0     |0             |0             |0             |0             |0             |0             |0                |0                    |0                     |0                  |2019-02-20 00:00:00.0|
|115392135|株洲市 |3       |长沙        |大丰      |0      |0         |3600           |20.0            |1795         |长沙市            |1805       |株洲市          |4         |0          |0       |0        |1            |48000.0           |156            |0.07                |2752    |420     |20.0     |4.0      |0             |0             |0             |0             |1             |1             |0                |2                    |0                     |0                  |2019-02-20 00:00:00.0|
|115392135|株洲市 |26      |大丰        |长沙      |0      |0         |7200           |75.0            |1805         |株洲市            |1795       |长沙市          |11        |0          |0       |0        |1            |48000.0           |161            |0.09                |2652    |480     |19.0     |5.0      |0             |0             |0             |0             |1             |1             |0                |2                    |0                     |0                  |2019-02-20 00:00:00.0|
+---------+----+--------+----------+--------+-------+----------+---------------+----------------+-------------+---------------+-----------+-------------+----------+-----------+--------+---------+-------------+------------------+---------------+--------------------+--------+--------+---------+---------+--------------+--------------+--------------+--------------+--------------+--------------+-----------------+---------------------+----------------------+-------------------+---------------------+
only showing top 20 rows

isclick is the label; the remaining fields are the candidate features.

Building the features

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

...
        VectorAssembler assembler = new VectorAssembler().setInputCols(new String[]{
                "category", "future_day", "banner_min_time","banner_min_price",
                "page_train", "page_flight", "page_bus", "page_transfer",
                "start_end_distance", "total_transport", "high_railway_percent", "avg_time", "min_time",
                "avg_price", "min_price",
                "label_05060801", "label_05060701", "label_05060601", "label_02050601", "label_02050501", "label_02050401",
                "is_match_category", "train_consumer_prefer", "flight_consumer_prefer", "bus_consumer_prefer"
        }).setOutputCol("features");

        Dataset<Row> trainData = assembler.transform(tableData);
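        // Base estimator. The maxIter, regParam and elasticNetParam values set here
        // are overridden by the grid values during cross-validation.
        LogisticRegression lr = new LogisticRegression()
                .setElasticNetParam(0)
                .setLabelCol("isclick")
                .setFeaturesCol("features")
                .setMaxIter(60)
                .setRegParam(0.1);

        // 2 x 2 x 2 = 8 parameter combinations to evaluate.
        ParamMap[] paramGrid = new ParamGridBuilder()
                .addGrid(lr.maxIter(), new int[]{50, 80})
                .addGrid(lr.regParam(), new double[]{0.005, 0.01})
                .addGrid(lr.elasticNetParam(), new double[]{0, 0.1})
                .build();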

        BinaryClassificationEvaluator binaryClassificationEvaluator = new BinaryClassificationEvaluator().setLabelCol("isclick");

        CrossValidator crossValidator = new CrossValidator()
                .setEstimator(lr)
                .setEvaluator(binaryClassificationEvaluator)
                .setEstimatorParamMaps(paramGrid)
                .setNumFolds(3);
        CrossValidatorModel cvModel = crossValidator.fit(trainData);

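Here I use logistic regression (LR), set the number of folds to 3, and evaluate with AUC (area under the ROC curve, the default metric of BinaryClassificationEvaluator). The resulting CrossValidatorModel holds the best model found by cross-validation.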
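To make the fold setting concrete, here is a minimal plain-Java sketch of how a 3-fold split partitions rows: each row is held out for validation in exactly one fold and used for training in the other two. The class KFoldSketch and its round-robin assignment by row index are assumptions made for readability; Spark's CrossValidator assigns rows to folds randomly, so this illustrates the idea rather than Spark's implementation.

```java
import java.util.ArrayList;
import java.util.List;

public class KFoldSketch {

    /** Returns, for each fold, the row indices held out for validation in that fold. */
    public static List<List<Integer>> validationFolds(int numRows, int numFolds) {
        List<List<Integer>> folds = new ArrayList<>();
        for (int f = 0; f < numFolds; f++) {
            folds.add(new ArrayList<>());
        }
        // Round-robin assignment by row index (an assumption; Spark samples randomly).
        for (int row = 0; row < numRows; row++) {
            folds.get(row % numFolds).add(row);
        }
        return folds;
    }

    /** Returns the training indices for a fold: every row not held out in it. */
    public static List<Integer> trainingRows(int numRows, int numFolds, int fold) {
        List<Integer> train = new ArrayList<>();
        for (int row = 0; row < numRows; row++) {
            if (row % numFolds != fold) {
                train.add(row);
            }
        }
        return train;
    }

    public static void main(String[] args) {
        int numRows = 9;   // toy row count, not the size of tableData
        int numFolds = 3;  // same fold count as setNumFolds(3)
        List<List<Integer>> folds = validationFolds(numRows, numFolds);
        for (int f = 0; f < numFolds; f++) {
            System.out.println("fold " + f
                    + " validate=" + folds.get(f)
                    + " train=" + trainingRows(numRows, numFolds, f));
        }
    }
}
```

With the 8 parameter combinations in paramGrid and 3 folds, cross-validation fits 8 x 3 = 24 models, then refits the best combination once on the full training data.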

Viewing the result for each parameter combination

        System.out.println("------------------------");
        ParamMap[] paramMaps = cvModel.getEstimatorParamMaps();
        double[] rocArea = cvModel.avgMetrics();
        for (int i = 0; i < paramMaps.length; i++) {
            System.out.println("------------" + i + "-------------");
            System.out.println("param:" + paramMaps[i]);
            System.out.println("rocArea:" + rocArea[i]);

        }

Output

------------0-------------
param:{
	logreg_54489c6b4a15-elasticNetParam: 0.0,
	logreg_54489c6b4a15-maxIter: 50,
	logreg_54489c6b4a15-regParam: 0.005
}
rocArea:0.7561587339955449
------------1-------------
param:{
	logreg_54489c6b4a15-elasticNetParam: 0.1,
	logreg_54489c6b4a15-maxIter: 50,
	logreg_54489c6b4a15-regParam: 0.005
}
rocArea:0.7601808941246924
------------2-------------
param:{
	logreg_54489c6b4a15-elasticNetParam: 0.0,
	logreg_54489c6b4a15-maxIter: 50,
	logreg_54489c6b4a15-regParam: 0.01
}
rocArea:0.7458205940087932
------------3-------------
param:{
	logreg_54489c6b4a15-elasticNetParam: 0.1,
	logreg_54489c6b4a15-maxIter: 50,
	logreg_54489c6b4a15-regParam: 0.01
}
rocArea:0.7519450829980202
------------4-------------
param:{
	logreg_54489c6b4a15-elasticNetParam: 0.0,
	logreg_54489c6b4a15-maxIter: 80,
	logreg_54489c6b4a15-regParam: 0.005
}
rocArea:0.7561587339955449
------------5-------------
param:{
	logreg_54489c6b4a15-elasticNetParam: 0.1,
	logreg_54489c6b4a15-maxIter: 80,
	logreg_54489c6b4a15-regParam: 0.005
}
rocArea:0.7600964897398843
------------6-------------
param:{
	logreg_54489c6b4a15-elasticNetParam: 0.0,
	logreg_54489c6b4a15-maxIter: 80,
	logreg_54489c6b4a15-regParam: 0.01
}
rocArea:0.7458205940087932
------------7-------------
param:{
	logreg_54489c6b4a15-elasticNetParam: 0.1,
	logreg_54489c6b4a15-maxIter: 80,
	logreg_54489c6b4a15-regParam: 0.01
}
rocArea:0.7519466667404875
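The winning combination can be read off this listing by taking the index with the largest average AUC. The sketch below rebuilds the 8 grid points in plain Java and picks that index. The class GridSearchSummary and its Combo type are hypothetical helpers, not Spark classes, and the loop nesting is chosen only to reproduce the order seen in the listing above; it is not a statement about how ParamGridBuilder orders its output in general. The metric values are copied from the listing.

```java
import java.util.ArrayList;
import java.util.List;

public class GridSearchSummary {

    /** One grid point, mirroring the three tuned parameters. */
    public static final class Combo {
        public final int maxIter;
        public final double regParam;
        public final double elasticNetParam;

        public Combo(int maxIter, double regParam, double elasticNetParam) {
            this.maxIter = maxIter;
            this.regParam = regParam;
            this.elasticNetParam = elasticNetParam;
        }

        @Override
        public String toString() {
            return "maxIter=" + maxIter + ", regParam=" + regParam
                    + ", elasticNetParam=" + elasticNetParam;
        }
    }

    /** Cartesian product of the candidate values; elasticNetParam varies fastest. */
    public static List<Combo> buildGrid(int[] maxIters, double[] regParams, double[] elasticNets) {
        List<Combo> grid = new ArrayList<>();
        for (int maxIter : maxIters) {
            for (double reg : regParams) {
                for (double en : elasticNets) {
                    grid.add(new Combo(maxIter, reg, en));
                }
            }
        }
        return grid;
    }

    /** Returns the index of the largest metric (the first one wins on ties). */
    public static int argMax(double[] metrics) {
        int best = 0;
        for (int i = 1; i < metrics.length; i++) {
            if (metrics[i] > metrics[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<Combo> grid = buildGrid(
                new int[]{50, 80}, new double[]{0.005, 0.01}, new double[]{0, 0.1});
        // Average AUC per combination, copied from the listing above.
        double[] rocArea = {
            0.7561587339955449, 0.7601808941246924, 0.7458205940087932, 0.7519450829980202,
            0.7561587339955449, 0.7600964897398843, 0.7458205940087932, 0.7519466667404875
        };
        int best = argMax(rocArea);
        System.out.println("best index: " + best);
        System.out.println("best combo: " + grid.get(best));
    }
}
```

Index 1 wins here (elasticNetParam 0.1, maxIter 50, regParam 0.005). Raising maxIter from 50 to 80 changes the scores only marginally or not at all, which suggests the optimizer had effectively converged by 50 iterations.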

Viewing the best model's parameters

        LogisticRegressionModel logisticRegressionModel = (LogisticRegressionModel) (cvModel.bestModel());
        System.out.println("---------coefficients is :------------------------------");
        System.out.println(logisticRegressionModel.coefficients());
        System.out.println("---------intercept is :------------------------------");
        System.out.println(logisticRegressionModel.intercept());
        System.out.println("------------best param is :------------------------------");
        System.out.println("ElasticNetParam:" + logisticRegressionModel.getElasticNetParam());
        System.out.println("RegParam:" + logisticRegressionModel.getRegParam());
        System.out.println("MaxIter:" + logisticRegressionModel.getMaxIter());

Output

---------coefficients is :------------------------------
[-0.020093624027311234,-0.023840313521126238,0.0,-4.870154374284758E-4,0.0,0.065071797632406,0.07375116647372162,0.0,2.1894539394004693E-7,-0.0018188669432944152,-0.0024137353437925675,-7.326062556226837E-6,-6.036644002424079E-7,4.186151682563356E-4,0.0,-0.4216808408256718,-0.25707113499256773,0.06709830247479415,0.0,0.009612099202465701,0.0,0.0,-0.08075082386845575,0.025511061645731456,-0.15643820797235122]
---------intercept is :------------------------------
-3.092110054323254
------------best param is :------------------------------
ElasticNetParam:0.1
RegParam:0.005
MaxIter:50
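For reference, a binary logistic regression model turns coefficients and an intercept like these into a click probability by passing the weighted feature sum plus the intercept through the sigmoid function. A minimal plain-Java sketch follows; LogisticScore is a hypothetical helper, and the three-element weight and feature vectors are made-up values for illustration, not data from this post. Only the intercept is copied from the output above.

```java
public class LogisticScore {

    /** Standard logistic (sigmoid) function. */
    public static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Probability of the positive class: sigmoid(w . x + b). */
    public static double probability(double[] coefficients, double intercept, double[] features) {
        if (coefficients.length != features.length) {
            throw new IllegalArgumentException("coefficient and feature lengths differ");
        }
        double margin = intercept;
        for (int i = 0; i < coefficients.length; i++) {
            margin += coefficients[i] * features[i];
        }
        return sigmoid(margin);
    }

    public static void main(String[] args) {
        double intercept = -3.092110054323254;  // copied from the output above
        double[] weights = {0.5, -0.25, 0.0};   // made-up weights for illustration
        double[] allZero = {0.0, 0.0, 0.0};     // made-up feature vector
        double[] sample = {2.0, 1.0, 5.0};      // made-up feature vector

        // With all-zero features the margin equals the intercept alone.
        System.out.println("baseline probability: " + probability(weights, intercept, allZero));
        System.out.println("sample probability:   " + probability(weights, intercept, sample));
    }
}
```

With all features at zero, the probability reduces to sigmoid(intercept), roughly 0.04 for the intercept above. Coefficients that are exactly 0.0 in the output contribute nothing to the margin.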

With this, cross-validation for Spark logistic regression runs end to end.