PySpark Machine Learning: Logistic Regression with a Pipeline

This post walks through a full Spark MLlib workflow on the iris dataset: preprocessing, feature engineering, model training, and evaluation, covering data loading, feature transformation, training, prediction, and assessment.


The iris dataset

Setting up the environment and importing modules

from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession

spark = SparkSession.builder.config(conf=SparkConf()).getOrCreate()

from pyspark.ml.linalg import Vector, Vectors
# StringType is needed for the schema below; PipelineModel for reloading the saved pipeline
from pyspark.sql.types import DoubleType, StringType, StructType, StructField
from pyspark.sql import Row, functions
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import IndexToString, StringIndexer, VectorIndexer, VectorAssembler
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel

Sample of the dataset

5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa

Reading the iris dataset

schema = StructType([
    StructField("_c0", DoubleType(), True),
    StructField("_c1", DoubleType(), True),
    StructField("_c2", DoubleType(), True),
    StructField("_c3", DoubleType(), True),
    StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
data.show(5)
+---+---+---+---+-----------+
|_c0|_c1|_c2|_c3|        _c4|
+---+---+---+---+-----------+
|5.1|3.5|1.4|0.2|Iris-setosa|
|4.9|3.0|1.4|0.2|Iris-setosa|
|4.7|3.2|1.3|0.2|Iris-setosa|
|4.6|3.1|1.5|0.2|Iris-setosa|
|5.0|3.6|1.4|0.2|Iris-setosa|
+---+---+---+---+-----------+
only showing top 5 rows
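A hedged alternative for this read step: instead of declaring the schema by hand, Spark can infer the column types, at the cost of an extra pass over the file (column names still default to _c0 through _c4):

# alternative sketch: infer types rather than declaring them
data = spark.read.csv("./datas/iris.data", inferSchema=True)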

Assembling the feature columns into a vector

df_assembler = VectorAssembler(inputCols=['_c0', '_c1', '_c2', '_c3'],
                               outputCol='features')
data = df_assembler.transform(data).select('features', '_c4')
data.show(5, truncate=False)
+-----------------+-----------+
|features         |_c4        |
+-----------------+-----------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|
|[4.9,3.0,1.4,0.2]|Iris-setosa|
|[4.7,3.2,1.3,0.2]|Iris-setosa|
|[4.6,3.1,1.5,0.2]|Iris-setosa|
|[5.0,3.6,1.4,0.2]|Iris-setosa|
+-----------------+-----------+
only showing top 5 rows
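One caveat worth knowing: by default VectorAssembler raises an error if an input column contains null. If the raw file may have missing values, a sketch of the workaround (Spark 2.4+) is the handleInvalid parameter:

df_assembler = VectorAssembler(inputCols=['_c0', '_c1', '_c2', '_c3'],
                               outputCol='features',
                               handleInvalid='skip')  # drop rows containing nulls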

Converting the label column to numeric indices

StringIndexer converts a string column into numeric indices ordered by frequency of occurrence: the most frequent value maps to 0.0, and the remaining values follow in order of decreasing frequency. For example, if the column has 5 distinct values, they are mapped to 0.0, 1.0, 2.0, 3.0, and 4.0.

labelIndexer = StringIndexer().setInputCol("_c4"). \
    setOutputCol("indexedLabel").fit(data)
data = labelIndexer.transform(data)
data.show(5)
+-----------------+-----------+------------+
|         features|        _c4|indexedLabel|
+-----------------+-----------+------------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|         0.0|
|[4.9,3.0,1.4,0.2]|Iris-setosa|         0.0|
|[4.7,3.2,1.3,0.2]|Iris-setosa|         0.0|
|[4.6,3.1,1.5,0.2]|Iris-setosa|         0.0|
|[5.0,3.6,1.4,0.2]|Iris-setosa|         0.0|
+-----------------+-----------+------------+
only showing top 5 rows
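To verify the learned mapping, the fitted StringIndexerModel exposes a labels list in which position i corresponds to index i.0; for this dataset it should contain the three species names:

print(labelIndexer.labels)
# e.g. ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']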

Indexing the features inside the vector

VectorIndexer is the vector counterpart of StringIndexer: it converts values inside a vector column into indices, with one extra parameter, maxCategories. A feature is converted to discrete indices only if the number of distinct values in its slot does not exceed maxCategories; features with more distinct values are left unchanged and treated as continuous.

featureIndexer = VectorIndexer(maxCategories=5).setInputCol("features"). \
    setOutputCol("indexedFeatures").fit(data)
data = featureIndexer.transform(data)
data.show(5)
+-----------------+-----------+------------+-----------------+
|         features|        _c4|indexedLabel|  indexedFeatures|
+-----------------+-----------+------------+-----------------+
|[5.1,3.5,1.4,0.2]|Iris-setosa|         0.0|[5.1,3.5,1.4,0.2]|
|[4.9,3.0,1.4,0.2]|Iris-setosa|         0.0|[4.9,3.0,1.4,0.2]|
|[4.7,3.2,1.3,0.2]|Iris-setosa|         0.0|[4.7,3.2,1.3,0.2]|
|[4.6,3.1,1.5,0.2]|Iris-setosa|         0.0|[4.6,3.1,1.5,0.2]|
|[5.0,3.6,1.4,0.2]|Iris-setosa|         0.0|[5.0,3.6,1.4,0.2]|
+-----------------+-----------+------------+-----------------+
only showing top 5 rows
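Note that indexedFeatures is identical to features in the output above. The fitted VectorIndexerModel records which vector slots it treated as categorical in its categoryMaps property; every iris feature has more than 5 distinct values, so the map should be empty and all four features are kept continuous:

print(featureIndexer.categoryMaps)
# expected: {} (no feature indexed as categorical)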

Splitting the dataset

trainData, testData = data.randomSplit([0.7, 0.3])
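randomSplit is nondeterministic across runs; if a reproducible split is needed, it accepts a seed (the value 42 below is arbitrary):

trainData, testData = data.randomSplit([0.7, 0.3], seed=42)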

Training and predicting with logistic regression

In the transformed testData, the prediction column holds the predicted label index and probability holds the per-class probabilities.

lr = LogisticRegression(labelCol='indexedLabel', featuresCol='indexedFeatures',
                        maxIter=100, regParam=0.3, elasticNetParam=0.8).fit(trainData)
testData = lr.transform(testData)
testData.show(5, truncate=False)
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
|features         |_c4            |indexedLabel|indexedFeatures  |rawPrediction                                                 |probability                                                |prediction|
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
|[4.6,3.1,1.5,0.2]|Iris-setosa    |0.0         |[4.6,3.1,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |
|[4.9,2.4,3.3,1.0]|Iris-versicolor|1.0         |[4.9,2.4,3.3,1.0]|[-0.24071644549981097,-0.34589675872056497,-0.454668906502002]|[0.3693377344468649,0.33246386640343106,0.2981983991497039]|0.0       |
|[4.9,3.1,1.5,0.1]|Iris-setosa    |0.0         |[4.9,3.1,1.5,0.1]|[0.5460088067214048,-0.34589675872056497,-0.646558280238517]  |[0.5836637411293467,0.2392285726499658,0.1771076862206876] |0.0       |
|[5.0,3.0,1.6,0.2]|Iris-setosa    |0.0         |[5.0,3.0,1.6,0.2]|[0.48785284033489695,-0.34589675872056497,-0.6252372387122376]|[0.5672280025589029,0.24641367880730375,0.1863583186337933]|0.0       |
|[5.0,3.4,1.5,0.2]|Iris-setosa    |0.0         |[5.0,3.4,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+
only showing top 5 rows
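The fitted LogisticRegressionModel exposes the learned multinomial parameters. With the L1-heavy penalty used here (elasticNetParam=0.8, regParam=0.3), many coefficients are expected to be driven to zero, which is why the raw prediction for one class is constant across the rows above:

print(lr.coefficientMatrix)   # 3 x 4 matrix, one row of coefficients per class
print(lr.interceptVector)     # one intercept per class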

Converting predictions back to labels

IndexToString does the opposite of StringIndexer: it converts numeric indices back into the original string labels.

labelConverter = IndexToString(inputCol='prediction', outputCol='predictedLabel',
                               labels=labelIndexer.labels)
testData = labelConverter.transform(testData)
testData.show(5, truncate=False)
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
|features         |_c4            |indexedLabel|indexedFeatures  |rawPrediction                                                 |probability                                                |prediction|predictedLabel|
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
|[4.6,3.1,1.5,0.2]|Iris-setosa    |0.0         |[4.6,3.1,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |Iris-setosa   |
|[4.9,2.4,3.3,1.0]|Iris-versicolor|1.0         |[4.9,2.4,3.3,1.0]|[-0.24071644549981097,-0.34589675872056497,-0.454668906502002]|[0.3693377344468649,0.33246386640343106,0.2981983991497039]|0.0       |Iris-setosa   |
|[4.9,3.1,1.5,0.1]|Iris-setosa    |0.0         |[4.9,3.1,1.5,0.1]|[0.5460088067214048,-0.34589675872056497,-0.646558280238517]  |[0.5836637411293467,0.2392285726499658,0.1771076862206876] |0.0       |Iris-setosa   |
|[5.0,3.0,1.6,0.2]|Iris-setosa    |0.0         |[5.0,3.0,1.6,0.2]|[0.48785284033489695,-0.34589675872056497,-0.6252372387122376]|[0.5672280025589029,0.24641367880730375,0.1863583186337933]|0.0       |Iris-setosa   |
|[5.0,3.4,1.5,0.2]|Iris-setosa    |0.0         |[5.0,3.4,1.5,0.2]|[0.5171107908618575,-0.34589675872056497,-0.6252372387122376] |[0.5743956461075854,0.24233253347984832,0.1832718204125663]|0.0       |Iris-setosa   |
+-----------------+---------------+------------+-----------------+--------------------------------------------------------------+-----------------------------------------------------------+----------+--------------+
only showing top 5 rows


The training workflow above involves many separate steps. In Spark, a Pipeline can chain these steps into a single workflow.

Applying a Pipeline

### Read the iris dataset
schema = StructType([
    StructField("_c0", DoubleType(), True),
    StructField("_c1", DoubleType(), True),
    StructField("_c2", DoubleType(), True),
    StructField("_c3", DoubleType(), True),
    StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
# data.show(5)

labelIndexer = StringIndexer().setInputCol("_c4"). \
    setOutputCol("indexedLabel").fit(data)
data = labelIndexer.transform(data)
# data.show()

trainData, testData = data.randomSplit([0.7, 0.3])

assembler = VectorAssembler(inputCols=['_c0','_c1','_c2','_c3'], outputCol='features')

featureIndexer = VectorIndexer().setInputCol('features'). \
    setOutputCol("indexedFeatures")


lr = LogisticRegression().\
    setLabelCol("indexedLabel"). \
    setFeaturesCol("indexedFeatures"). \
    setMaxIter(100). \
    setRegParam(0.3). \
    setElasticNetParam(0.8)
# print("LogisticRegression parameters:\n" + lr.explainParams())

labelConverter = IndexToString(). \
    setInputCol("prediction"). \
    setOutputCol("predictedLabel"). \
    setLabels(labelIndexer.labels)
lrPipeline = Pipeline(). \
    setStages([assembler, featureIndexer, lr, labelConverter])

lrPipelineModel = lrPipeline.fit(trainData)

## Save the model
lrPipelineModel.save('./data/lr_model')
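save() raises an error if the path already exists from an earlier run; one workaround is the overwrite writer:

lrPipelineModel.write().overwrite().save('./data/lr_model')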

Loading the model and predicting on the test set

### Load the model
l_r = PipelineModel.load('./data/lr_model')

lrPredictions = l_r.transform(testData)
preRel = lrPredictions.select("predictedLabel","_c4","features","probability")
preRel.show()

evaluator = MulticlassClassificationEvaluator(). \
    setLabelCol("indexedLabel"). \
    setPredictionCol("prediction")
lrAccuracy = evaluator.evaluate(lrPredictions)
lrAccuracy
+---------------+---------------+-----------------+--------------------+
| predictedLabel|            _c4|         features|         probability|
+---------------+---------------+-----------------+--------------------+
|    Iris-setosa|    Iris-setosa|[4.4,3.0,1.3,0.2]|[0.55438709514625...|
|    Iris-setosa|    Iris-setosa|[4.5,2.3,1.3,0.3]|[0.54484378407858...|
|    Iris-setosa|    Iris-setosa|[4.6,3.2,1.4,0.2]|[0.54764212881478...|
|    Iris-setosa|    Iris-setosa|[4.6,3.6,1.0,0.2]|[0.57449259710402...|
|    Iris-setosa|    Iris-setosa|[4.8,3.0,1.4,0.1]|[0.55711960386173...|
|    Iris-setosa|    Iris-setosa|[4.8,3.4,1.9,0.2]|[0.51370352584700...|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|[0.55038245390810...|
|    Iris-setosa|    Iris-setosa|[4.9,3.1,1.5,0.1]|[0.55038245390810...|
|Iris-versicolor|Iris-versicolor|[5.0,2.0,3.5,1.0]|[0.33266831839072...|
|    Iris-setosa|    Iris-setosa|[5.0,3.3,1.4,0.2]|[0.54764212881478...|
|    Iris-setosa|    Iris-setosa|[5.1,3.5,1.4,0.3]|[0.53807472588121...|
|    Iris-setosa|    Iris-setosa|[5.1,3.8,1.9,0.4]|[0.49437428087769...|
|    Iris-setosa|    Iris-setosa|[5.2,3.4,1.4,0.2]|[0.54764212881478...|
|    Iris-setosa|    Iris-setosa|[5.3,3.7,1.5,0.2]|[0.54087961843218...|
|    Iris-setosa|    Iris-setosa|[5.4,3.4,1.7,0.2]|[0.52731180331887...|
|Iris-versicolor|Iris-versicolor|[5.5,2.4,3.8,1.1]|[0.30612463894390...|
|Iris-versicolor|Iris-versicolor|[5.6,2.5,3.9,1.1]|[0.30036420997267...|
|Iris-versicolor|Iris-versicolor|[5.6,2.7,4.2,1.3]|[0.26721917673024...|
|Iris-versicolor|Iris-versicolor|[5.6,3.0,4.1,1.3]|[0.27259161091433...|
| Iris-virginica|Iris-versicolor|[5.6,3.0,4.5,1.5]|[0.23632948715024...|
+---------------+---------------+-----------------+--------------------+
only showing top 20 rows

0.8417582417582418
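One caveat: MulticlassClassificationEvaluator defaults to metricName="f1", so the value above is an F1 score rather than accuracy. To report accuracy, set the metric explicitly:

evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy")
lrAccuracy = evaluator.evaluate(lrPredictions)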

Below, the data is transformed one pipeline stage at a time; finally, the loaded pipeline model produces the same result in a single transform call.

schema = StructType([
    StructField("_c0", DoubleType(), True),
    StructField("_c1", DoubleType(), True),
    StructField("_c2", DoubleType(), True),
    StructField("_c3", DoubleType(), True),
    StructField("_c4", StringType(), True)])
data = spark.read.csv("./datas/iris.data",schema=schema)
data = data.drop('_c4')
data.show(5)
data1 = lrPipelineModel.stages[0].transform(data)
data1.show(5)
data2 = lrPipelineModel.stages[1].transform(data1)
data2.show(5)
data3 = lrPipelineModel.stages[2].transform(data2)
data3.show(5)
data4 = lrPipelineModel.stages[3].transform(data3)
data4.show(5)

l_r = PipelineModel.load('./data/lr_model')
lrPredictions = l_r.transform(data)
lrPredictions.show(5)
+---+---+---+---+
|_c0|_c1|_c2|_c3|
+---+---+---+---+
|5.1|3.5|1.4|0.2|
|4.9|3.0|1.4|0.2|
|4.7|3.2|1.3|0.2|
|4.6|3.1|1.5|0.2|
|5.0|3.6|1.4|0.2|
+---+---+---+---+
only showing top 5 rows

+---+---+---+---+-----------------+
|_c0|_c1|_c2|_c3|         features|
+---+---+---+---+-----------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|
+---+---+---+---+-----------------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|
+---+---+---+---+-----------------+-----------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|
+---+---+---+---+-----------------+-----------------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|       rawPrediction|         probability|prediction|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...|       0.0|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...|       0.0|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|       rawPrediction|         probability|prediction|predictedLabel|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...|       0.0|   Iris-setosa|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...|       0.0|   Iris-setosa|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
only showing top 5 rows

+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|_c0|_c1|_c2|_c3|         features|  indexedFeatures|       rawPrediction|         probability|prediction|predictedLabel|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
|5.1|3.5|1.4|0.2|[5.1,3.5,1.4,0.2]|[5.1,3.5,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.9|3.0|1.4|0.2|[4.9,3.0,1.4,0.2]|[4.9,3.0,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
|4.7|3.2|1.3|0.2|[4.7,3.2,1.3,0.2]|[4.7,3.2,1.3,0.2]|[0.50343051284394...|[0.55438709514625...|       0.0|   Iris-setosa|
|4.6|3.1|1.5|0.2|[4.6,3.1,1.5,0.2]|[4.6,3.1,1.5,0.2]|[0.44890228364925...|[0.54087961843218...|       0.0|   Iris-setosa|
|5.0|3.6|1.4|0.2|[5.0,3.6,1.4,0.2]|[5.0,3.6,1.4,0.2]|[0.47616639824660...|[0.54764212881478...|       0.0|   Iris-setosa|
+---+---+---+---+-----------------+-----------------+--------------------+--------------------+----------+--------------+
only showing top 5 rows