PySpark machine learning: logistic regression with Pipeline

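This article uses Spark MLlib to take the Iris dataset through the full workflow: reading the data, transforming features, training a logistic regression model, predicting, and evaluating the results.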

Iris dataset

Set up the environment and import modules

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

spark = SparkSession.builder.config(conf=SparkConf()).getOrCreate()

from pyspark.ml.linalg import Vector, Vectors
# StringType is needed for the label column in the schema below
from pyspark.sql.types import DoubleType, StringType, StructType, StructField
from pyspark.sql import Row, functions
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# PipelineModel is needed to load the saved pipeline later on
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer, VectorAssembler
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel, \
    BinaryLogisticRegressionSummary

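Dataset sample (iris.data has no header row; the columns are sepal length, sepal width, petal length and petal width in cm, followed by the class)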

5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa

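Read the Iris dataset. Because the file has no header, an explicit schema is passed that uses the column names Spark would assign by default (_c0 to _c4).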

schema = StructType([
    StructField("_c0", DoubleType(), True),
    StructField("_c1", DoubleType(), True),
    StructField("_c2", DoubleType(), True),
    StructField("_c3", DoubleType(), True),
    StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
data.show(5)
+---+---+---+---+-----------+
|_c0|_c1|_c2|_c3|        _c4|
+---+---+---+---+-----------+
|5.1|3.5|1.4|0.2|Iris-setosa|
|4.9|3.0|1.4|0.2|Iris-setosa|
|4.7|3.2|1.3|0.2|Iris-setosa|
|4.6|3.1|1.5|0.2|Iris-setosa|
|5.0|3.6|1.4|0.2|Iris-setosa|
+---+---+---+---+-----------+
only showing top 5 rows

Assemble the feature columns into a single vector column

df_assembler = VectorAssembler(inputCols=['_c0','_c1','_c2',\
                                         '_c3'], outputCol='features')
data = df_assembler.transform(data).select('features','_c4')
data.show(5,truncate=False)
+-----------------+-----------+
|features         |_c4        |
+-----------------+-----------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|
|[4.9,3.0,1.4,0.2]|Iris-setosa|
|[4.7,3.2,1.3,0.2]|Iris-setosa|
|[4.6,3.1,1.5,0.2]|Iris-setosa|
|[5.0,3.6,1.4,0.2]|Iris-setosa|
+-----------------+-----------+
only showing top 5 rows

Convert the label column to numeric indices

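StringIndexer maps a string column to numeric indices ordered by how often each value occurs (the default stringOrderType is frequencyDesc). For example, if the column has 5 distinct values, the most frequent one becomes 0.0 and the others become 1.0, 2.0, 3.0 and 4.0 in decreasing order of frequency.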
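To make the frequency ordering concrete, here is a plain-Python sketch of the idea (not Spark's implementation, and the function name `string_index` is hypothetical). Ties are broken alphabetically here; my understanding is that recent Spark versions do the same for frequencyDesc, but treat that as an assumption. It matters for Iris, where all three classes have 50 rows each, so the order there is decided entirely by tie-breaking.

```python
from collections import Counter


def string_index(values):
    """Map each distinct string to a float index, most frequent first.

    Assumption: equal counts are ordered alphabetically.
    """
    counts = Counter(values)
    # Sort by descending count, then alphabetically for equal counts
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    mapping = {label: float(i) for i, label in enumerate(ordered)}
    return [mapping[v] for v in values], ordered


labels = ["b", "a", "c", "a", "b", "a"]
indices, ordered = string_index(labels)
print(ordered)  # ['a', 'b', 'c']
print(indices)  # [1.0, 0.0, 2.0, 0.0, 1.0, 0.0]
```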

labelIndexer = StringIndexer().setInputCol("_c4"). \
    setOutputCol("indexedLabel").fit(data)
data = labelIndexer.transform(data)
data.show(5)
+-----------------+-----------+------------+
|         features|        _c4|indexedLabel|
+-----------------+-----------+------------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|         0.0|
|[4.9,3.0,1.4,0.2]|Iris-setosa|         0.0|
|[4.7,3.2,1.3,0.2]|Iris-setosa|         0.0|
|[4.6,3.1,1.5,0.2]|Iris-setosa|         0.0|
|[5.0,3.6,1.4,0.2]|Iris-setosa|         0.0|
+-----------------+-----------+------------+
only showing top 5 rows

Index the categorical features inside the vector

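VectorIndexer is similar to StringIndexer, but it works on the entries of a vector column and takes an extra parameter, maxCategories. A feature (one position in the vector) with no more than maxCategories distinct values is treated as categorical and its values are replaced by discrete indices; a feature with more distinct values is left unchanged and treated as continuous. Every Iris feature has more than 5 distinct values, which is why indexedFeatures is identical to features in the output below.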
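The maxCategories rule can be sketched in plain Python as follows. This is a simplified illustration, not Spark's code: `vector_index` is a hypothetical name, and the sketch simply sorts the distinct values, whereas Spark additionally reserves index 0 for the value 0.0 when it is present.

```python
def vector_index(rows, max_categories):
    """Decide per feature whether it is categorical, then index it.

    rows: list of equal-length lists of floats (one list per vector).
    A feature with <= max_categories distinct values is categorical and
    each value is replaced by its rank among the sorted distinct values.
    Other features pass through unchanged (treated as continuous).
    """
    n_features = len(rows[0])
    category_maps = {}
    for j in range(n_features):
        distinct = sorted({row[j] for row in rows})
        if len(distinct) <= max_categories:
            category_maps[j] = {v: float(i) for i, v in enumerate(distinct)}
    indexed = [
        [category_maps[j][v] if j in category_maps else v
         for j, v in enumerate(row)]
        for row in rows
    ]
    return indexed, category_maps


# Feature 0 has 5 distinct values, feature 1 only 3
rows = [[5.1, 20.0], [4.9, 10.0], [4.7, 20.0], [4.6, 30.0], [5.0, 10.0]]
indexed, maps = vector_index(rows, max_categories=3)
print(list(maps))  # [1] -> only the second feature is categorical
print(indexed[3])  # [4.6, 2.0]
```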

featureIndexer = VectorIndexer(maxCategories=5).setInputCol("features"). \
    setOutputCol("indexedFeatures").fit(data)
data = featureIndexer.transform(data)
data.show(5)
+-----------------+-----------+------------+-----------------+
|         features|        _c4|indexedLabel|  indexedFeatures|
+-----------------+-----------+------------+-----------------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|         0.0|[5.1,3.5,1.4,0.2]|
|[4.9,3.0,1.4,0.2]|Iris-setosa|         0.0|[4.9,3.0,1.4,0.2]|
|[4.7,3.2,1.3,0.2]|Iris-setosa|         0.0|[4.7,3.2,1.3,0.2]|
|[4.6,3.1,1.5,0.2]|Iris-setosa|         0.0|[4.6,3.1,1.5,0.2]|
|[5.0,3.6,1.4,0.2]|Iris-setosa|         0.0|[5.0,3.6,1.4,0.2]|
+-----------------+-----------+------------+-----------------+
only showing top 5 rows

Split the dataset into training and test sets

trainData, testData = data.randomSplit([0.7, 0.3])
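randomSplit normalises the weights and assigns each row independently, so the 70/30 split is only approximate and differs between runs unless a seed is passed (randomSplit accepts a seed argument). A plain-Python sketch of that idea, with the hypothetical name `random_split`; Spark itself samples per partition, so its exact assignments will differ:

```python
import random


def random_split(rows, weights, seed=None):
    """Assign each row independently to one of len(weights) splits.

    Weights are normalised to sum to 1; each row draws one uniform
    number and goes to the split whose cumulative range contains it,
    so split sizes are only approximately proportional to the weights.
    """
    total = sum(weights)
    bounds = []
    acc = 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        for i, b in enumerate(bounds):
            # the last split also catches any floating-point leftover
            if u < b or i == len(bounds) - 1:
                splits[i].append(row)
                break
    return splits


train, test = random_split(list(range(150)), [0.7, 0.3], seed=42)
print(len(train), len(test))  # roughly 105 and 45, not exactly
```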

Train a logistic regression model and predict

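In the transformed testData, prediction is the predicted class index and probability holds the probability of each class; rawPrediction holds the raw per-class scores from which probability is computed.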
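For a three-class (multinomial) model, probability should be the softmax of rawPrediction, and prediction the index of the largest probability when no thresholds are set. The sketch below checks this against the first row of the testData output in this section; `softmax` is a helper written for illustration, not a Spark function.

```python
import math


def softmax(raw):
    """Turn raw per-class scores into probabilities that sum to 1."""
    m = max(raw)  # subtract the max for numerical stability
    exps = [math.exp(r - m) for r in raw]
    total = sum(exps)
    return [e / total for e in exps]


# rawPrediction of the first row ([4.6,3.1,1.5,0.2]) in the output table
raw = [0.5171107908618575, -0.34589675872056497, -0.6252372387122376]
prob = softmax(raw)
# prediction = index of the largest probability, as a float like Spark's
prediction = float(max(range(len(prob)), key=prob.__getitem__))
print(prob)        # approximately [0.5744, 0.2423, 0.1833]
print(prediction)  # 0.0
```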

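# fit() returns a fitted LogisticRegressionModel; transform() adds
# the rawPrediction, probability and prediction columns
lr = LogisticRegression(labelCol='indexedLabel', featuresCol='indexedFeatures',
                        maxIter=100, regParam=0.3, elasticNetParam=0.8).fit(trainData)
testData = lr.transform(testData)
testData.show(5, truncate=False)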
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
|features         |_c4            |indexedLabel|indexedFeatures  |rawPrediction                                                 |probability                                                |prediction|
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
|[4.6,3.1,1.5,0.2]|Iris-setosa    |0.0         |[4.6,3.1,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |
|[4.9,2.4,3.3,1.0]|Iris-versicolor|1.0         |[4.9,2.4,3.3,1.0]|[-0.24071644549981097,-0.34589675872056497,-0.454668906502002]|[0.3693377344468649,0.33246386640343106,0.2981983991497039]|0.0       |
|[4.9,3.1,1.5,0.1]|Iris-setosa    |0.0         |[4.9,3.1,1.5,0.1]|[0.5460088067214048,-0.34589675872056497,-0.646558280238517]  |[0.5836637411293467,0.2392285726499658,0.1771076862206876] |0.0       |
|[5.0,3.0,1.6,0.2]|Iris-setosa    |0.0         |[5.0,3.0,1.6,0.2]|[0.48785284033489695,-0.34589675872056497,-0.6252372387122376]|[0.5672280025589029,0.24641367880730375,0.1863583186337933]|0.0       |
|[5.0,3.4,1.5,0.2]|Iris-setosa    |0.0         |[5.0,3.4,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
only showing top 5 rows

Convert the predictions back to labels

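IndexToString does the opposite of StringIndexer: it maps numeric indices back to string labels, here using the labels list learned by the fitted labelIndexer.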
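The inverse mapping amounts to a list lookup. In this plain-Python sketch, `index_to_string` is hypothetical, and the labels list is a hypothetical example ordered to agree with the output in this section (Iris-setosa as 0.0, Iris-versicolor as 1.0); in the article the real list comes from labelIndexer.labels.

```python
def index_to_string(indices, labels):
    """Map float indices back to the string labels they stand for."""
    return [labels[int(i)] for i in indices]


# Hypothetical label order, standing in for labelIndexer.labels
labels = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]
print(index_to_string([0.0, 2.0, 1.0], labels))
# ['Iris-setosa', 'Iris-virginica', 'Iris-versicolor']
```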

labelConverter = IndexToString(inputCol='prediction', outputCol='predictedLabel',
                               labels=labelIndexer.labels)
testData = labelConverter.transform(testData)
testData.show(5, truncate=False)
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
|features         |_c4            |indexedLabel|indexedFeatures  |rawPrediction                                                 |probability                                                |prediction|predictedLabel|
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
|[4.6,3.1,1.5,0.2]|Iris-setosa    |0.0         |[4.6,3.1,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |Iris-setosa   |
|[4.9,2.4,3.3,1.0]|Iris-versicolor|1.0         |[4.9,2.4,3.3,1.0]|[-0.24071644549981097,-0.34589675872056497,-0.454668906502002]|[0.3693377344468649,0.33246386640343106,0.2981983991497039]|0.0       |Iris-setosa   |
|[4.9,3.1,1.5,0.1]|Iris-setosa    |0.0         |[4.9,3.1,1.5,0.1]|[0.5460088067214048,-0.34589675872056497,-0.646558280238517]  |[0.5836637411293467,0.2392285726499658,0.1771076862206876] |0.0       |Iris-setosa   |
|[5.0,3.0,1.6,0.2]|Iris-setosa    |0.0         |[5.0,3.0,1.6,0.2]|[0.48785284033489695,-0.34589675872056497,-0.6252372387122376]|[0.5672280025589029,0.24641367880730375,0.1863583186337933]|0.0       |Iris-setosa   |
|[5.0,3.4,1.5,0.2]|Iris-setosa    |0.0         |[5.0,3.4,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |Iris-setosa   |
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
only showing top 5 rows


Training the model above took many separate steps. In Spark, a Pipeline chains these steps together into a single workflow.
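A toy, plain-Python model of the Estimator/Transformer contract behind Pipeline may help. All class names here are hypothetical, not PySpark's. fit walks the stages in order: an Estimator is fitted on the data produced by the earlier stages and replaced by its fitted model, while a Transformer is kept as is. The result is a model whose transform runs every fitted stage in sequence, which is also why a fitted PipelineModel's stages can be applied one by one. (Spark skips the final transform after the last estimator; this toy transforms after every stage, which is harmless.)

```python
class Transformer:
    def transform(self, rows):
        raise NotImplementedError


class Estimator:
    def fit(self, rows):
        raise NotImplementedError  # must return a Transformer


class Scale(Transformer):
    """Stateless stage: multiply field 'x' by a constant."""
    def __init__(self, factor):
        self.factor = factor

    def transform(self, rows):
        return [{**r, "x": r["x"] * self.factor} for r in rows]


class MeanCenterModel(Transformer):
    def __init__(self, mean):
        self.mean = mean

    def transform(self, rows):
        return [{**r, "x": r["x"] - self.mean} for r in rows]


class MeanCenter(Estimator):
    """Stage that must learn something (the mean) before transforming."""
    def fit(self, rows):
        return MeanCenterModel(sum(r["x"] for r in rows) / len(rows))


class ToyPipelineModel(Transformer):
    def __init__(self, stages):
        self.stages = stages

    def transform(self, rows):
        for stage in self.stages:
            rows = stage.transform(rows)
        return rows


class ToyPipeline(Estimator):
    def __init__(self, stages):
        self.stages = stages

    def fit(self, rows):
        fitted = []
        for stage in self.stages:
            model = stage.fit(rows) if isinstance(stage, Estimator) else stage
            fitted.append(model)
            rows = model.transform(rows)  # feed this output to the next stage
        return ToyPipelineModel(fitted)


train = [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}]
model = ToyPipeline([Scale(2.0), MeanCenter()]).fit(train)
print(model.transform([{"x": 5.0}]))  # [{'x': 6.0}]
```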

Using a Pipeline

### Read the Iris dataset
schema = StructType([
    StructField("_c0", DoubleType(), True),
    StructField("_c1", DoubleType(), True),
    StructField("_c2", DoubleType(), True),
    StructField("_c3", DoubleType(), True),
    StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
# data.show(5)

labelIndexer = StringIndexer().setInputCol("_c4"). \
    setOutputCol("indexedLabel").fit(data)
data = labelIndexer.transform(data)
# data.show()

trainData, testData = data.randomSplit([0.7, 0.3])

assembler = VectorAssembler(inputCols=['_c0','_c1','_c2','_c3'], outputCol='features')

# maxCategories defaults to 20; set it to 5 to match the step-by-step version above
featureIndexer = VectorIndexer(maxCategories=5).setInputCol('features'). \
    setOutputCol("indexedFeatures")


lr = LogisticRegression().\
    setLabelCol("indexedLabel"). \
    setFeaturesCol("indexedFeatures"). \
    setMaxIter(100). \
    setRegParam(0.3). \
    setElasticNetParam(0.8)
# print("LogisticRegression parameters:\n" + lr.explainParams())

labelConverter = IndexToString(). \
    setInputCol("prediction"). \
    setOutputCol("predictedLabel"). \
    setLabels(labelIndexer.labels)
lrPipeline = Pipeline(). \
    setStages([assembler, featureIndexer, lr, labelConverter])

lrPipelineModel = lrPipeline.fit(trainData)

## Save the model (overwrite() lets this be rerun when the directory already exists)
lrPipelineModel.write().overwrite().save('./data/lr_model')

Load the model and predict on the test set

### Load the model
l_r = PipelineModel.load('./data/lr_model')

lrPredictions = l_r.transform(testData)
preRel = lrPredictions.select("predictedLabel","_c4","features","probability")
preRel.show()

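# The default metricName is "f1" (weighted F1), so lrAccuracy is really an F1 score;
# add .setMetricName("accuracy") to compute accuracy instead
evaluator = MulticlassClassificationEvaluator(). \
    setLabelCol("indexedLabel"). \
    setPredictionCol("prediction")
lrAccuracy = evaluator.evaluate(lrPredictions)
lrAccuracy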
+---------------+---------------+-----------------+--------------------+
| predictedLabel|            _c4|         features|         probability|
+---------------+---------------+-----------------+--------------------+
|    Iris-setosa|    Iris-setosa|[4.4,3.0,1.3,0.2]|[0.55438709514625...|
|    Iris-setosa|    Iris-setosa|[4.5,2.3,1.3,0.3]|[0.54484378407858...|
|    Iris-setosa|    Iris-setosa|[4.6,3.2,1.4,0.2]|[0.54764212881478...|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|[0.57449259710402...|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|[0.55711960386173...|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|[0.51370352584700...|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|[0.55038245390810...|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|[0.55038245390810...|
|Iris-versicolor|Iris-versicolor|[5.0,2.0,3.5,1.0]|[0.33266831839072...|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|[0.54764212881478...|
|    Iris-setosa|    Iris-setosa|[5.1,3.5,1.4,0.3]|[0.53807472588121...|
|    Iris-setosa|    Iris-setosa|[5.1,3.8,1.9,0.4]|[0.49437428087769...|
|    Iris-setosa|    Iris-setosa|[5.2,3.4,1.4,0.2]|[0.54764212881478...|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|[0.54087961843218...|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.7,0.2]|[0.52731180331887...|
|Iris-versicolor|Iris-versicolor|[5.5,2.4,3.8,1.1]|[0.30612463894390...|
|Iris-versicolor|Iris-versicolor|[5.6,2.5,3.9,1.1]|[0.30036420997267...|
|Iris-versicolor|Iris-versicolor|[5.6,2.7,4.2,1.3]|[0.26721917673024...|
|Iris-versicolor|Iris-versicolor|[5.6,3.0,4.1,1.3]|[0.27259161091433...|
| Iris-virginica|Iris-versicolor|[5.6,3.0,4.5,1.5]|[0.23632948715024...|
+---------------+---------------+-----------------+--------------------+
only showing top 20 rows

0.8417582417582418
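Because MulticlassClassificationEvaluator defaults to metricName="f1", the value above is a weighted F1 score: per-class F1 weighted by each class's share of the true labels. It is not accuracy. The plain-Python sketch below computes both metrics on a small made-up set of labels and predictions to show that they differ; the function names are hypothetical, not Spark's.

```python
from collections import Counter


def accuracy(labels, preds):
    """Fraction of predictions equal to the true label."""
    return sum(l == p for l, p in zip(labels, preds)) / len(labels)


def weighted_f1(labels, preds):
    """Per-class F1, weighted by each class's count among the true labels."""
    support = Counter(labels)
    total = len(labels)
    score = 0.0
    for c, n in support.items():
        tp = sum(l == c and p == c for l, p in zip(labels, preds))
        predicted_c = sum(p == c for p in preds)
        precision = tp / predicted_c if predicted_c else 0.0
        recall = tp / n
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        score += f1 * n / total
    return score


labels = [0.0, 0.0, 1.0, 1.0, 2.0, 2.0]
preds = [0.0, 0.0, 1.0, 2.0, 2.0, 2.0]
print(accuracy(labels, preds))     # ~0.833
print(weighted_f1(labels, preds))  # ~0.822
```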

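Below, the data is transformed stage by stage with the stages of the fitted pipeline, and then the loaded pipeline model produces the same result in a single step.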

schema = StructType([
    StructField("_c0", DoubleType(), True),
    StructField("_c1", DoubleType(), True),
    StructField("_c2", DoubleType(), True),
    StructField("_c3", DoubleType(), True),
    StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
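data = data.drop('_c4')  # the label is not needed to make predictions
data.show(5)
# stage 0: VectorAssembler adds the features column
data1 = lrPipelineModel.stages[0].transform(data)
data1.show(5)
# stage 1: the fitted VectorIndexerModel adds indexedFeatures
data2 = lrPipelineModel.stages[1].transform(data1)
data2.show(5)
# stage 2: the fitted LogisticRegressionModel adds rawPrediction, probability and prediction
data3 = lrPipelineModel.stages[2].transform(data2)
data3.show(5)
# stage 3: IndexToString adds predictedLabel
data4 = lrPipelineModel.stages[3].transform(data3)
data4.show(5)

# the loaded PipelineModel runs all four stages in a single transform call
l_r = PipelineModel.load('./data/lr_model')
lrPredictions = l_r.transform(data)
lrPredictions.show(5)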
+---+---+---+---+
|_c0|_c1|_c2|_c3|
+---+---+---+---+
|5.1|3.5|1.4|0.2|
|4.9|3.0|1.4|0.2|
|4.7|3.2|1.3|0.2|
|4.6|3.1|1.5|0.2|
|5.0|3.6|1.4|0.2|
+---+---+---+---+
only showing top 5 rows

+---+---+---+---+-----------------+
|_c0|_c1|_c2|_c3|         features|
+---+---+---+---+-----------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|
+---+---+---+---+-----------------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|
+---+---+---+---+-----------------+-----------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|
+---+---+---+---+-----------------+-----------------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|       rawPrediction|         probability|prediction|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...|       0.0|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...|       0.0|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|       rawPrediction|         probability|prediction|predictedLabel|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...|       0.0|   Iris-setosa|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...|       0.0|   Iris-setosa|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|       rawPrediction|         probability|prediction|predictedLabel|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...|       0.0|   Iris-setosa|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...|       0.0|   Iris-setosa|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
only showing top 5 rows