Spark ML (3): Implementing Regression Algorithms (Linear Regression, Logistic Regression)

This post covers the Spark environment configuration and the code for two regression algorithms. The environment is a self-compiled spark-2.1.0-cdh5.7.0 running on CDH 5.7.0. Preparation consists of setting up a Spark client debugging environment, creating a Scala project, and adding the pom dependencies. The implementation section shows a sample of the test data and then walks through linear regression and logistic regression.


I. Environment Configuration

1. spark-2.1.0-cdh5.7.0 (self-compiled)

2. CDH 5.7.0

3. Scala 2.11.8

4. CentOS 6.4

II. Environment Preparation

1. Set up a Spark client debugging environment

Reference: https://blog.youkuaiyun.com/u010886217/article/details/83279157

2. Create a Scala project

Reference: https://blog.youkuaiyun.com/u010886217/article/details/84332961

3. Add the pom dependencies

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>sparktest</groupId>
  <artifactId>sparktest</artifactId>
  <version>1.0-SNAPSHOT</version>
  <inceptionYear>2008</inceptionYear>
  <properties>
    <scala.version>2.11.8</scala.version>
    <kafka.version>0.9.0.0</kafka.version>
    <hbase.version>1.2.0-cdh5.7.0</hbase.version>
    <spark.version>2.1.0</spark.version>
    <hadoop.version>2.6.0-cdh5.7.0</hadoop.version>
  </properties>

  <repositories>
    <repository>
      <id>cloudera</id>
      <url>https://repository.cloudera.com/artifactory/cloudera-repos/</url>
    </repository>
  </repositories>

  <dependencies>
    <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-mllib -->
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-mllib_2.11</artifactId>
      <version>2.1.0</version>
    </dependency>

    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming_2.11</artifactId>
      <version>2.1.0</version>
    </dependency>

    <dependency>
      <groupId>org.scala-lang</groupId>
      <artifactId>scala-library</artifactId>
      <version>${scala.version}</version>
    </dependency>

    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>

    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-core_2.11</artifactId>
      <version>${spark.version}</version>
    </dependency>

  </dependencies>

  <build>
    <sourceDirectory>src/main/scala</sourceDirectory>
    <plugins>
      <plugin>
        <groupId>org.scala-tools</groupId>
        <artifactId>maven-scala-plugin</artifactId>
        <executions>
          <execution>
            <goals>
              <goal>compile</goal>
              <goal>testCompile</goal>
            </goals>
          </execution>
        </executions>
        <configuration>
          <scalaVersion>${scala.version}</scalaVersion>
          <args>
            <arg>-target:jvm-1.5</arg>
          </args>
        </configuration>
      </plugin>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-eclipse-plugin</artifactId>
        <configuration>
          <downloadSources>true</downloadSources>
          <buildcommands>
            <buildcommand>ch.epfl.lamp.sdt.core.scalabuilder</buildcommand>
          </buildcommands>
          <additionalProjectnatures>
            <projectnature>ch.epfl.lamp.sdt.core.scalanature</projectnature>
          </additionalProjectnatures>
          <classpathContainers>
            <classpathContainer>org.eclipse.jdt.launching.JRE_CONTAINER</classpathContainer>
            <classpathContainer>ch.epfl.lamp.sdt.launching.SCALA_CONTAINER</classpathContainer>
          </classpathContainers>
        </configuration>
      </plugin>
    </plugins>
  </build>
  <reporting>
    <plugins>
      <plugin>
        <groupId>org.scala-tools</groupId>
        <artifactId>maven-scala-plugin</artifactId>
        <configuration>
          <scalaVersion>${scala.version}</scalaVersion>
        </configuration>
      </plugin>
    </plugins>
  </reporting>
</project>

III. Code Implementation

1. Sample test data

position;square;price;direction;type;name;
0;190;20000;0;4室2厅2卫;中信城(别墅);
0;190;20000;0;4室2厅2卫;中信城(别墅);
5;400;15000;0;4室3厅3卫;融创上城;
0;500;15000;0;5室3厅2卫;中海莱茵东郡;
5;500;15000;0;5室3厅4卫;融创上城(别墅);
1;320;15000;1;1室1厅1卫;长江花园;
0;143;12000;0;3室2厅2卫;融创上城;
0;200;10000;0;4室3厅2卫;中海莱茵东郡(别墅);
0;207;9000;0;4室3厅4卫;中海莱茵东郡;
0;130;8500;0;3室2厅2卫;伟峰东第;
5;150;7000;0;3室2厅2卫;融创上城;
2;178;6000;0;4室2厅2卫;鸿城国际花园;
5;190;6000;0;3室2厅2卫;亚泰豪苑C栋;
1;150;6000;0;5室1厅2卫;通安新居A区;
2;165;6000;0;3室2厅2卫;万科惠斯勒小镇;
0;64;5500;0;1室1厅1卫;保利中央公园;
2;105;5500;0;2室2厅1卫;虹馆;
1;160;5300;0;3室2厅1卫;昊源高格蓝湾;
2;170;5100;0;4室2厅2卫;亚泰鼎盛国际;
...
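
Since the csv reader loads every column as a string, the numeric columns have to be cast explicitly; that is what the toDouble calls in the examples below do. A quick way to confirm the loaded schema is a minimal standalone sketch like the following (the SchemaCheck object name is only for illustration; it assumes house.csv sits in the working directory):

package sparktest

import org.apache.spark.sql.SparkSession

// Standalone sketch: check how the semicolon-separated sample file loads
object SchemaCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("schema-check").master("local").getOrCreate()
    val file = spark.read.format("csv").option("sep", ";").option("header", "true").load("house.csv")
    file.printSchema() // every column is StringType, hence the toDouble casts in the examples below
    file.show(5, truncate = false)
    spark.stop()
  }
}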

2. Linear regression

package sparktest

import org.apache.spark.SparkConf
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession

import scala.util.Random

object Main {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("linear").setMaster("local")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    // Read the semicolon-separated file; with the csv source every column arrives as a string
    val file = spark.read.format("csv").option("sep", ";").option("header", "true").load("house.csv")

    import spark.implicits._
    // Cast the two columns to Double and shuffle the rows via a random sort key
    val rand = new Random()
    val data = file.select("square", "price")
      .map(row => (row.getString(0).toDouble, row.getString(1).toDouble, rand.nextDouble()))
      .toDF("square", "price", "rand")
      .sort("rand")

    // Pack the feature column into a single vector column named "features"
    val ass = new VectorAssembler().setInputCols(Array("square")).setOutputCol("features")
    val dataset = ass.transform(data)

    // Split into training and test sets
    val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))

    // Linear regression: predict price from square
    val lr = new LinearRegression()
      .setStandardization(true)
      .setMaxIter(10)
      .setFeaturesCol("features")
      .setLabelCol("price")

    val model = lr.fit(train)     // train
    model.transform(test).show()  // predict on the held-out set
  }
}
Result:
+------+-------+--------------------+--------+------------------+
|square|  price|                rand|features|        prediction|
+------+-------+--------------------+--------+------------------+
|  64.0| 1400.0|0.006545997025104056|  [64.0]| 1871.626837251842|
| 100.0| 2600.0|0.006070889056102979| [100.0]|1910.2354430535167|
|  10.0|  450.0|0.016279200291292373|  [10.0]|  1813.71392854933|
|  60.0| 1700.0| 0.01773595114007931|  [60.0]| 1867.336992162767|
| 150.0| 7000.0| 0.01799868562447937| [150.0]|1963.8585066669539|
| 320.0|15000.0|0.022266421888484267| [320.0]|  2146.17692295264|
|  42.0| 1200.0| 0.03158087604172155|  [42.0]|1848.0326892619296|
|  88.0| 1800.0|  0.0340325321865349|  [88.0]| 1897.365907786292|
|  60.0| 1600.0| 0.05842848976518067|  [60.0]| 1867.336992162767|
| 154.0| 3700.0| 0.08695690147815338| [154.0]| 1968.148351756029|
|  96.0| 2500.0| 0.08956069761188501|  [96.0]|1905.9455979644417|
|  63.0| 1400.0|  0.1058435529752908|  [63.0]|1870.5543759795733|
| 100.0| 2200.0| 0.12881102655257837| [100.0]|1910.2354430535167|
|  20.0|  500.0| 0.13298961275676147|  [20.0]|1824.4385412720173|
| 148.0| 2800.0|  0.1347286681517027| [148.0]|1961.7135841224165|
|  92.0| 2950.0|  0.1418523082181563|  [92.0]|1901.6557528753667|
|  96.0| 2300.0| 0.14272158486886666|  [96.0]|1905.9455979644417|
|  70.0| 1500.0| 0.14371869433210183|  [70.0]|1878.0616048854545|
|  18.0|  600.0|  0.1523863581129299|  [18.0]|  1822.29361872748|
|  20.0|  800.0| 0.17536250365755057|  [20.0]|1824.4385412720173|
+------+-------+--------------------+--------+------------------+
only showing top 20 rows
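
To put a number on how well the fitted line generalizes, spark.ml's RegressionEvaluator can score the held-out split. A minimal sketch that could be appended inside main() after the fit, reusing the model and test values from the example above (the metric choice is illustrative):

    import org.apache.spark.ml.evaluation.RegressionEvaluator

    // Score the held-out split; `model` and `test` come from the code above
    val predictions = model.transform(test)
    val evaluator = new RegressionEvaluator()
      .setLabelCol("price")
      .setPredictionCol("prediction")
      .setMetricName("rmse")       // "mae" and "r2" are also supported
    println(s"test RMSE: ${evaluator.evaluate(predictions)}")

    // The fitted line: price ≈ coefficient * square + intercept
    println(s"coefficients: ${model.coefficients}, intercept: ${model.intercept}")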

3. Logistic regression

package sparktest

import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

import scala.util.Random

// Named differently from the linear-regression Main so the two examples can coexist in one project
object LogisticMain {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("logistic").setMaster("local")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    // Read the semicolon-separated file; with the csv source every column arrives as a string
    val file = spark.read.format("csv").option("sep", ";").option("header", "true").load("house.csv")

    import spark.implicits._
    // Cast the two columns to Double and shuffle the rows via a random sort key
    val rand = new Random()
    val data = file.select("square", "price")
      .map(row => (row.getString(0).toDouble, row.getString(1).toDouble, rand.nextDouble()))
      .toDF("square", "price", "rand")
      .sort("rand")

    // Pack the feature column into a single vector column named "features"
    val ass = new VectorAssembler().setInputCols(Array("square")).setOutputCol("features")
    val dataset = ass.transform(data)

    // Split into training and test sets
    val Array(train, test) = dataset.randomSplit(Array(0.8, 0.2))

    // Logistic regression (a classifier); here the raw price column is used as the label
    val lr = new LogisticRegression()
      .setLabelCol("price")
      .setFeaturesCol("features")
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
      .setMaxIter(10)

    val model = lr.fit(train)
    model.transform(test).show()

    val iters = model.summary.totalIterations
    println(s"iter: ${iters}")
  }
}

Result:
+------+------+--------------------+--------+--------------------+--------------------+----------+
|square| price|                rand|features|       rawPrediction|         probability|prediction|
+------+------+--------------------+--------+--------------------+--------------------+----------+
|  43.0|1600.0|0.001716305364737214|  [43.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  60.0|1600.0|0.010588427013326074|  [60.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  61.0|1300.0|0.043301012076277345|  [61.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  60.0|1600.0| 0.05231439852503761|  [60.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  20.0| 600.0| 0.05386280768045393|  [20.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  89.0|2800.0|  0.0650227911769532|  [89.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  60.0|1500.0| 0.06793901574354433|  [60.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  90.0|2500.0| 0.07541330585084804|  [90.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  65.0|1300.0| 0.07727780227514891|  [65.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  25.0|1300.0| 0.09515681816587895|  [25.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  50.0|1400.0| 0.08681645057310305|  [50.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  60.0|1600.0| 0.10042920576336689|  [60.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  41.0|1500.0| 0.11564005495013441|  [41.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  92.0|2950.0| 0.11751726539452112|  [92.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  71.0|1600.0| 0.12520507959550664|  [71.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
| 110.0|2700.0| 0.13631041935375054| [110.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  10.0| 450.0|  0.1429523132917182|  [10.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  65.0|1700.0| 0.15676743340088062|  [65.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  90.0|2100.0| 0.18187817541586593|  [90.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
|  62.0|1500.0| 0.19149303955455987|  [62.0]|[-0.0253687225534...|[6.85377863744923...|    1500.0|
+------+------+--------------------+--------+--------------------+--------------------+----------+
only showing top 20 rows
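
Note that LogisticRegression is a classifier: it treats every distinct price value as a class index, which is why each test row above ends up with the same predicted class. For a more typical classification setup, the price can first be turned into a binary label, for example "expensive or not" with an arbitrary cutoff. A minimal sketch that could replace the training step above, reusing the dataset DataFrame; the 5000 threshold and the "label" column name are illustrative assumptions, not from the original post:

    import org.apache.spark.sql.functions.{col, when}

    // Hypothetical binary label: 1.0 if price >= 5000, else 0.0 (the cutoff is an arbitrary illustration)
    val labeled = dataset.withColumn("label", when(col("price") >= 5000, 1.0).otherwise(0.0))
    val Array(trainB, testB) = labeled.randomSplit(Array(0.8, 0.2))

    val clf = new LogisticRegression()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
      .setMaxIter(10)

    val clfModel = clf.fit(trainB)
    clfModel.transform(testB)
      .select("square", "price", "label", "probability", "prediction")
      .show()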

 
