Spark 多项式逻辑回归__多分类

本文介绍了一种使用Spark MLlib进行多项逻辑回归的实现方法。通过加载数据、预处理、训练模型及评估等步骤,展示了如何在多分类问题上应用逻辑回归。
package Spark_MLlib

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession


object 多项式逻辑回归__多分类 {
  val spark=SparkSession.builder().master("local").getOrCreate()
  import spark.implicits._  //支持把一个RDD隐式转换为一个DataFrame
  def main(args: Array[String]): Unit = {
    val df =spark.sparkContext.textFile("file:///home/soyo/桌面/spark编程测试数据/soyo.txt")
      .map(_.split(",")).map(x=>data_schema(Vectors.dense(x(0).toDouble,x(1).toDouble,x(2).toDouble,x(3).toDouble),x(4))).toDF()
   // df.show(150)
    val labelIndexer=new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(df)
    val featureIndexer=new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(df)  //目的在特征向量中建类别索引
    val Array(trainData,testData)=df.randomSplit(Array(0.7,0.3))
    val lr=new LogisticRegression().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8).setFamily("multinomial")//设置elasticnet混合参数为0.8,setFamily("multinomial"):设置为多项逻辑回归,不设置setFamily为二项逻辑回归
    val labelConverter=new IndexToString().setInputCol("prediction").setOutputCol("predictionLabel").setLabels(labelIndexer.labels)

    val lrPipeline=new Pipeline().setStages(Array(labelIndexer,featureIndexer,lr,labelConverter))
    val lrPipeline_Model=lrPipeline.fit(trainData)
    val lrPrediction=lrPipeline_Model.transform(testData)
    lrPrediction.show(150)
    // lrPrediction.take(100).foreach(println)
    //模型评估
    val evaluator=new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction")
    val lrAccuracy=evaluator.evaluate(lrPrediction)
    println("准确率为: "+lrAccuracy)
    val lrError=1-lrAccuracy
    println("错误率为: "+lrError)
    val LRmodel=lrPipeline_Model.stages(2).asInstanceOf[LogisticRegressionModel]
    println("二项逻辑回归模型系数矩阵: "+LRmodel.coefficientMatrix)
    println("二项逻辑回归模型的截距向量: "+LRmodel.interceptVector)
    println("类的数量(标签可以使用的值): "+LRmodel.numClasses)
    println("模型所接受的特征的数量: "+LRmodel.numFeatures)
  //多项式逻辑回归不包含对模型的摘要总结
    println(LRmodel.hasSummary)

  }

}

结果:

+-----------------+-----+------------+-----------------+--------------------+--------------------+----------+---------------+
|         features|label|indexedLabel|  indexedFeatures|       rawPrediction|         probability|prediction|predictionLabel|
+-----------------+-----+------------+-----------------+--------------------+--------------------+----------+---------------+
|[4.4,3.2,1.3,0.2]|soyo1|         1.0|[4.4,3.2,1.3,0.2]|[0.06313829278191...|[0.23858281707128...|       1.0|          soyo1|
|[4.6,3.4,1.4,0.3]|soyo1|         1.0|[4.6,3.4,1.4,0.3]|[0.06313829278191...|[0.23750012598226...|       1.0|          soyo1|
|[4.7,3.2,1.6,0.2]|soyo1|         1.0|[4.7,3.2,1.6,0.2]|[0.06313829278191...|[0.24710416166321...|       1.0|          soyo1|
|[4.8,3.4,1.6,0.2]|soyo1|         1.0|[4.8,3.4,1.6,0.2]|[0.06313829278191...|[0.23716995683018...|       1.0|          soyo1|
|[4.8,3.4,1.9,0.2]|soyo1|         1.0|[4.8,3.4,1.9,0.2]|[0.06313829278191...|[0.24567798276462...|       1.0|          soyo1|
|[4.9,2.4,3.3,1.0]|soyo2|         0.0|[4.9,2.4,3.3,1.0]|[0.06313829278191...|[0.38071131817453...|       0.0|          soyo2|
|[5.0,3.2,1.2,0.2]|soyo1|         1.0|[5.0,3.2,1.2,0.2]|[0.06313829278191...|[0.23576075216827...|       1.0|          soyo1|
|[5.0,3.5,1.3,0.3]|soyo1|         1.0|[5.0,3.5,1.3,0.3]|[0.06313829278191...|[0.22978111243935...|       1.0|          soyo1|
|[5.2,4.1,1.5,0.1]|soyo1|         1.0|[5.2,4.1,1.5,0.1]|[0.06313829278191...|[0.19523110424215...|       1.0|          soyo1|
|[5.4,3.9,1.3,0.4]|soyo1|         1.0|[5.4,3.9,1.3,0.4]|[0.06313829278191...|[0.21630436073381...|       1.0|          soyo1|
|[5.5,2.4,3.8,1.1]|soyo2|         0.0|[5.5,2.4,3.8,1.1]|[0.06313829278191...|[0.39807479409636...|       0.0|          soyo2|
|[5.5,2.5,4.0,1.3]|soyo2|         0.0|[5.5,2.5,4.0,1.3]|[0.06313829278191...|[0.40810357240132...|       0.0|          soyo2|
|[5.6,2.8,4.9,2.0]|soyo3|         2.0|[5.6,2.8,4.9,2.0]|[0.06313829278191...|[0.44454733071968...|       0.0|          soyo2|
|[5.7,2.9,4.2,1.3]|soyo2|         0.0|[5.7,2.9,4.2,1.3]|[0.06313829278191...|[0.39634982244233...|       0.0|          soyo2|
|[5.8,2.6,4.0,1.2]|soyo2|         0.0|[5.8,2.6,4.0,1.2]|[0.06313829278191...|[0.39930520027794...|       0.0|          soyo2|
|[5.8,2.7,4.1,1.0]|soyo2|         0.0|[5.8,2.7,4.1,1.0]|[0.06313829278191...|[0.38762610877473...|       0.0|          soyo2|
|[5.8,2.7,5.1,1.9]|soyo3|         2.0|[5.8,2.7,5.1,1.9]|[0.06313829278191...|[0.44792417666537...|       0.0|          soyo2|
|[5.9,3.0,5.1,1.8]|soyo3|         2.0|[5.9,3.0,5.1,1.8]|[0.06313829278191...|[0.43418725338764...|       0.0|          soyo2|
|[6.0,2.2,4.0,1.0]|soyo2|         0.0|[6.0,2.2,4.0,1.0]|[0.06313829278191...|[0.40634099537710...|       0.0|          soyo2|
|[6.0,2.7,5.1,1.6]|soyo2|         0.0|[6.0,2.7,5.1,1.6]|[0.06313829278191...|[0.43688076686419...|       0.0|          soyo2|
|[6.0,3.4,4.5,1.6]|soyo2|         0.0|[6.0,3.4,4.5,1.6]|[0.06313829278191...|[0.39704954911011...|       0.0|          soyo2|
|[6.2,2.2,4.5,1.5]|soyo2|         0.0|[6.2,2.2,4.5,1.5]|[0.06313829278191...|[0.43847273913421...|       0.0|          soyo2|
|[6.2,2.8,4.8,1.8]|soyo3|         2.0|[6.2,2.8,4.8,1.8]|[0.06313829278191...|[0.43518321759857...|       0.0|          soyo2|
|[6.3,2.7,4.9,1.8]|soyo3|         2.0|[6.3,2.7,4.9,1.8]|[0.06313829278191...|[0.44055947195014...|       0.0|          soyo2|
|[6.3,2.9,5.6,1.8]|soyo3|         2.0|[6.3,2.9,5.6,1.8]|[0.06313829278191...|[0.44715759200377...|       0.0|          soyo2|
|[6.3,3.4,5.6,2.4]|soyo3|         2.0|[6.3,3.4,5.6,2.4]|[0.06313829278191...|[0.45196576310313...|       0.0|          soyo2|
|[6.4,2.8,5.6,2.1]|soyo3|         2.0|[6.4,2.8,5.6,2.1]|[0.06313829278191...|[0.46017875340546...|       0.0|          soyo2|
|[6.4,2.8,5.6,2.2]|soyo3|         2.0|[6.4,2.8,5.6,2.2]|[0.06313829278191...|[0.46321910727428...|       0.0|          soyo2|
|[6.4,3.1,5.5,1.8]|soyo3|         2.0|[6.4,3.1,5.5,1.8]|[0.06313829278191...|[0.43862320280893...|       0.0|          soyo2|
|[6.4,3.2,4.5,1.5]|soyo2|         0.0|[6.4,3.2,4.5,1.5]|[0.06313829278191...|[0.40056786531830...|       0.0|          soyo2|
|[6.5,3.0,5.5,1.8]|soyo3|         2.0|[6.5,3.0,5.5,1.8]|[0.06313829278191...|[0.44199581778961...|       0.0|          soyo2|
|[6.6,2.9,4.6,1.3]|soyo2|         0.0|[6.6,2.9,4.6,1.3]|[0.06313829278191...|[0.40579282648595...|       0.0|          soyo2|
|[6.7,2.5,5.8,1.8]|soyo3|         2.0|[6.7,2.5,5.8,1.8]|[0.06313829278191...|[0.46287803722998...|       0.0|          soyo2|
|[6.7,3.0,5.2,2.3]|soyo3|         2.0|[6.7,3.0,5.2,2.3]|[0.06313829278191...|[0.45387841693477...|       0.0|          soyo2|
|[6.7,3.1,4.7,1.5]|soyo2|         0.0|[6.7,3.1,4.7,1.5]|[0.06313829278191...|[0.40924150360290...|       0.0|          soyo2|
|[6.7,3.3,5.7,2.5]|soyo3|         2.0|[6.7,3.3,5.7,2.5]|[0.06313829278191...|[0.45972648058424...|       0.0|          soyo2|
|[6.8,3.0,5.5,2.1]|soyo3|         2.0|[6.8,3.0,5.5,2.1]|[0.06313829278191...|[0.45251276088924...|       0.0|          soyo2|
|[6.8,3.2,5.9,2.3]|soyo3|         2.0|[6.8,3.2,5.9,2.3]|[0.06313829278191...|[0.45975331380088...|       0.0|          soyo2|
|[6.9,3.2,5.7,2.3]|soyo3|         2.0|[6.9,3.2,5.7,2.3]|[0.06313829278191...|[0.45642868507279...|       0.0|          soyo2|
|[7.2,3.0,5.8,1.6]|soyo3|         2.0|[7.2,3.0,5.8,1.6]|[0.06313829278191...|[0.44031726493318...|       0.0|          soyo2|
|[7.2,3.2,6.0,1.8]|soyo3|         2.0|[7.2,3.2,6.0,1.8]|[0.06313829278191...|[0.44483171938259...|       0.0|          soyo2|
|[7.6,3.0,6.6,2.1]|soyo3|         2.0|[7.6,3.0,6.6,2.1]|[0.06313829278191...|[0.47047723863543...|       0.0|          soyo2|
|[7.7,3.0,6.1,2.3]|soyo3|         2.0|[7.7,3.0,6.1,2.3]|[0.06313829278191...|[0.46845272424381...|       0.0|          soyo2|
|[7.7,3.8,6.7,2.2]|soyo3|         2.0|[7.7,3.8,6.7,2.2]|[0.06313829278191...|[0.45233124776236...|       0.0|          soyo2|
+-----------------+-----+------------+-----------------+--------------------+--------------------+----------+---------------+

准确率为: 0.36458333333333337
错误率为: 0.6354166666666666
二项逻辑回归模型系数矩阵: 3 x 4 CSCMatrix
(1,1) 0.35559564188466614
(1,2) -0.203185158868005
(1,3) -0.43876460704959996
(2,3) 0.0283914830858408
二项逻辑回归模型的截距向量: [0.06313829278191783,0.1708622138778958,-0.23400050665981365]
类的数量(标签可以使用的值): 3
模型所接受的特征的数量: 4
false

转载于:https://www.cnblogs.com/soyo/p/7787647.html

<think>好的,用户想了解如何使用逻辑回归进行预测或分类。首先,我需要回忆逻辑回归的基本概念和步骤。逻辑回归虽然名字里有“回归”,但其实是分类算法,常用于二分类,也可以通过一些方法扩展到多分类。 根据引用中的内容,引用[2]提到逻辑回归是二分类算法,使用sigmoid函数预测概率。而引用[1]和引用[3]分别展示了用sklearn实现多分类Spark处理文本分类的例子。用户的问题可能包括二分类和多分类,需要覆盖这些方面。 接下来,步骤应该包括数据准备、模型训练、预测和评估。需要确保提到特征工程,比如标准化、文本处理如TF-IDF,这在引用[3]里有提到。模型训练部分,sklearn的LogisticRegression是常用工具,如引用[1]的例子。多分类问题可以通过One-vs-Rest或多项式分布,如引用[1]所述。 可能用户是刚入门的,所以代码示例要简单明了。同时,要解释关键参数,比如正则化项C,这在避免过拟合时很重要。评估指标如准确率、混淆矩阵也需要提到,帮助用户了解模型表现。 还要考虑应用场景,比如用户提到的文本分类,如引用[3]中的例子,可能需要处理文本数据转换成特征向量,这部分需要说明清楚,比如使用TF-IDF或CountVectorizer。 用户可能遇到的问题包括数据预处理不当、特征选择不合适,或者参数调整,这些可以放在注意事项里。另外,要区分二分类和多分类的实现方法,确保用户知道如何调整模型应对不同问题。 最后,生成相关问题时要覆盖逻辑回归的原理、多分类方法、特征工程技巧和评估指标,帮助用户深入学习和解决问题。</think>### 如何使用逻辑回归进行预测或分类? 逻辑回归是一种广泛使用的分类算法,适用于二分类和多分类任务。以下是具体步骤和示例: --- #### **1. 数据准备** - **特征工程**:将数据转换为数值特征,并进行标准化/归一化(如使用`StandardScaler`)。 - **文本处理**(如引用[3]):将文本合并为一列后,通过TF-IDF或词袋模型转换为数值特征。 - **划分数据集**:按比例(如8:2)分为训练集和测试集。 --- #### **2. 模型训练** **二分类示例(引用[2])**: ```python from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split # 假设X为特征矩阵,y为二分类标签 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2) model = LogisticRegression() model.fit(X_train, y_train) ``` **多分类示例(引用[1])**: ```python # 使用One-vs-Rest策略处理多分类 model = LogisticRegression(multi_class='ovr') model.fit(X_train, y_train) ``` --- #### **3. 预测评估** - **预测概率**:输出属于各类别的概率。 ```python probabilities = model.predict_proba(X_test) ``` - **直接分类结果**: ```python predictions = model.predict(X_test) ``` - **评估指标**: ```python from sklearn.metrics import accuracy_score, confusion_matrix print("准确率:", accuracy_score(y_test, predictions)) print("混淆矩阵:", confusion_matrix(y_test, predictions)) ``` --- #### **4. 关键参数说明** - `penalty`:正则化类型(默认`l2`),防止过拟合。 - `C`:正则化强度,值越小正则化越强。 - `solver`:优化算法(如`liblinear`适用于小数据集,`sag`适合大数据集)。 --- #### **应用场景** 1. **二分类**:如预测用户是否购买商品、邮件是否为垃圾邮件[^2]。 2. **多分类**:如新闻分类、手写数字识别(需调整参数`multi_class`)[^1]。 3. **文本分类**:通过TF-IDF处理文本后进行分类(如引用[3]中的新闻标题分类)。 --- #### **注意事项** - 逻辑回归假设特征目标呈线性关系,非线性问题需结合多项式特征。 - 类别不平衡时需设置`class_weight='balanced'`。 - 高维稀疏数据(如文本)建议搭配正则化使用。 ---
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符  | 博主筛选后可见
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值