<转>spark下线性模型 spark.mllib

本文详细介绍了Spark MLlib中的线性方法,包括分类、回归等算法,并提供了具体的实现代码示例,如SVM、逻辑回归和线性回归等。

我还是参考官方的文档来写这个部分,顺便梳理下原理,给出对应代码及运行结果,一点也不复杂。

数学公式

许多的机器学习算法实际上可以被写成凸优化的问题,比如说寻找凸函数 f 的极小值,它取决于权重向量w,那么我们可以将优化目标函数写成: 
这里写图片描述 
这里 xiRd 是训练数据, yiR 是它们对应的标签,线性方法可以表示成 L(w;x,y) ,有几类mllib中的分类和回归算法都可以归为这一类。 
目标函数由两个部分,正则项,控制模型的复杂度,以及loss(亏损函数),它评估训练数据的模型误差。

Loss functions

mllib中支持的亏损函数和它们的梯度(对w)为: 
这里写图片描述

正则项

正则项鼓励简单的模型,避免overfitting. 
这里写图片描述
L2约束一般来说比L1约束更平滑,然而,L1约束可以帮助提升权重项的稀疏性,因此可以获得更小的容易解释的模型,在特征选择方面非常有用。Elastic网是L1和L2的结合。

优化

使用凸优化的方法来对目标函数进行优化,Spark.mllib使用两种方法,分别是SGD和L-BFGS,我们在优化这一章来进一部解释。目前,大多数算法的APIs支持随机梯度下降算法SGD,少数支持L-BFGS算法。

分类

常见的有二分类的问题,将样本分为正样本和负样本,超过两类就是多分类问题。 
在spark.mllib中,支持两种线性分类方法,分别是SVMs以及逻辑回归。线性SVMs的方法仅仅支持二分类,逻辑回归同时还支持多分类的问题。对于两种方法,spark.mllib都支持L1和L2的规则项。 
训练集用MLlib中的LabeledPoint的RDD来表示,所有的标签都是从0开始的,需要注意的是,在二分类问题的数学表示中,负样本写成-1, 然而这里我们把负样本写成0.

线性SVM

损失函数定义为: 
这里写图片描述 
默认情况下,线性SVM使用的是L2约束,当然同时也可以使用L1约束,这种情况下它就变成线性规划问题。 
这是我在Intellij下运行通过的Scala版本的代码,可读性非常高。

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.util.MLUtils

object testSVM{

  def main(args:Array[String]): Unit ={


    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("testSVM")
    var sc = new SparkContext(conf)

    // Load training data in LIBSVM format.
    val data = MLUtils.loadLibSVMFile(sc, "/home/hadoop/spark/data/mllib/sample_libsvm_data.txt")

    // Split data into training (60%) and test (40%).
    val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
    val training = splits(0).cache()
    val test = splits(1)

    // Run training algorithm to build the model
    val numIterations = 100
    val model = SVMWithSGD.train(training, numIterations)

    // Clear the default threshold.
    model.clearThreshold()

    // Compute raw scores on the test set.
    val scoreAndLabels = test.map { point =>
      val score = model.predict(point.features)
      (score, point.label)
    }

    // Get evaluation metrics.
    val metrics = new BinaryClassificationMetrics(scoreAndLabels)
    val auROC = metrics.areaUnderROC()

    println("Area under ROC = " + auROC)

    // Save and load model
    model.save(sc, "myModelPath")
    val sameModel = SVMModel.load(sc, "myModelPath")

  }

}

 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50

运行可以得到结果为: 
Area under ROC = 1.0 
注意输入的格式为: 
label index1:value1 index2:value2 … 
它是稀疏的 
Alt text

逻辑回归

逻辑回归中的损失函数可以表示成: 
这里写图片描述 
对于二分类问题,算法输出一个二值的逻辑回归模型,给定一个数据点,表示成x,通过运用逻辑函数 
这里写图片描述 
表示,其中 z=WTx , 如果 f(z)>0.5 那么输出就为正,样本为正样本,否则为负,样本为负样本。 
二值逻辑回归也可以推广到多模态,用于多分类问题。比如说有K个可能的输出,其中一个输出作为pivot,其余的K-1个用于与之区分。在spark.mllib中,第一个class 0 作为pivot类。 
多分类的问题由K-1个二值的逻辑回归组成,给定一个新的数据点,我们将运行K-1个模型,拥有最大的概率的类将被选为预测的类。 
我们实现两个算法来求解逻辑回归问题: 一个是mini-batch的梯度下降算法,一个是L-BFGS算法。我们推荐使用L-BFGS。

/**
  * Created by hadoop on 16-2-16.
  */
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, LogisticRegressionModel}
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLUtils
object testLR {
  def main(args:Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("testSVM")
    var sc = new SparkContext(conf)
    // Load training data in LIBSVM format.
    val data = MLUtils.loadLibSVMFile(sc, "/home/hadoop/spark/data/mllib/sample_libsvm_data.txt")

    // Split data into training (60%) and test (40%).
    val splits = data.randomSplit(Array(0.6, 0.4), seed = 11L)
    val training = splits(0).cache()
    val test = splits(1)

    // Run training algorithm to build the model
    val model = new LogisticRegressionWithLBFGS()
      .setNumClasses(10)
      .run(training)

    // Compute raw scores on the test set.
    val predictionAndLabels = test.map { case LabeledPoint(label, features) =>
      val prediction = model.predict(features)
      (prediction, label)
    }

    // Get evaluation metrics.
    val metrics = new MulticlassMetrics(predictionAndLabels)
    val precision = metrics.precision
    println("Precision = " + precision)

    // Save and load model
    model.save(sc, "myModelPath")
    val sameModel = LogisticRegressionModel.load(sc, "myModelPath")
  }
}
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

输出为: 
Precision = 1.0

Regression

Linear least squares, Lasso, and ridge regression

linear least squares是回归问题中最常见的构造,损失函数可以写成 
这里写图片描述 
可以使用不同类型的规则项,比如说ordinary least squares 或者linear least squares 它们没有使用规则项, 
ridge regression使用L2规则项,Lasso使用的L1规则项。对于所有的模型,平均损失以及训练误差为 
这里写图片描述 
也就是平均squared error。 
下面的例子,首先载入数据,解析成LabeledPoint的RDD格式,随后使用LinearRegressionWithSGD来构造一个简单的线性model来预测值。用squared error来表示拟合情况。

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.regression.LinearRegressionModel
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors
object Regression {
  def main(args:Array[String]): Unit ={
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("testSVM")
    var sc = new SparkContext(conf)

    // Load and parse the data
    val data = sc.textFile("/home/hadoop/spark/data/mllib/ridge-data/lpsa.data")
    val parsedData = data.map { line =>
      val parts = line.split(',')
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble)))
    }.cache()

    // Building the model
    val numIterations = 100
    val model = LinearRegressionWithSGD.train(parsedData, numIterations)

    // Evaluate model on training examples and compute training error
    val valuesAndPreds = parsedData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean()
    println("training Mean Squared Error = " + MSE)

    // Save and load model
    model.save(sc, "myModelPath")
    val sameModel = LinearRegressionModel.load(sc, "myModelPath")
  }
}
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36

输出结果为: 
training Mean Squared Error = 6.207597210613578

注意其中的,创建dense vector的方法

// Create a dense vector (1.0, 0.0, 3.0).
val dv: Vector = Vectors.dense(1.0, 0.0, 3.0)
// Create a sparse vector (1.0, 0.0, 3.0) by specifying its indices and values corresponding to nonzero entries.
val sv1: Vector = Vectors.sparse(3, Array(0, 2), Array(1.0, 3.0))
// Create a sparse vector (1.0, 0.0, 3.0) by specifying its nonzero entries.
val sv2: Vector = Vectors.sparse(3, Seq((0, 1.0), (2, 3.0))) //第一项为向量的长度
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

Streaming linear regression

当数据按照流的方式到来时,采用在线回归模型的方式是很好的,目前mllib支持streaming 线性回归。这个拟合和离线的方式差不多,只是每来一批数据,就拟合一次,所以可以不断的更新。

开发者实现

mllib实现了一个简单的分布式的SGD,基于原始的梯度下降,算法的规则项为regParam,以及不同的参数用于随机梯度下降(stepSize, numIterations, miniBatchFraction)。对于每一项,都支持三种可能的规则项(none,L1,L2) 
For Logistic Regression, L-BFGS version is implemented under LogisticRegressionWithLBFGS, and this version supports both binary and multinomial Logistic Regression while SGD version only supports binary Logistic Regression. However, L-BFGS version doesn’t support L1 regularization but SGD one supports L1 regularization. When L1 regularization is not required, L-BFGS version is strongly recommended since it converges faster and more accurately compared to SGD by approximating the inverse Hessian matrix using quasi-Newton method.

Algorithms are all implemented in Scala:

SVMWithSGD
LogisticRegressionWithLBFGS
LogisticRegressionWithSGD
LinearRegressionWithSGD
RidgeRegressionWithSGD
LassoWithSGD

参考文献 
http://spark.apache.org/docs/latest/mllib-linear-methods.html

package ads import common.PortraitCommon.{ck_dim, ckdriver, ckpassword, ckurl, ckuser, dim_user_info_tmp} import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.col import utils.{ClickhouseUtils, SparkUtils} object ads_car_brand_cut { def main(args: Array[String]): Unit = { // TODO: 构建spark环境 val spark: SparkSession = SparkUtils.getSpark("ads_car_brand_cut") // TODO: 取数据 val dwdreadData = spark.read.format("jdbc") .option("url", ckurl) .option("user", ckuser) .option("password", ckpassword) .option("dbtable", "tieta_v2_dim.dim_user_info") .option("driver", ckdriver) .load() // val mysqlurl = "jdbc:mysql://192.168.3.117:3309/lol" // val mysqlusername = "root" // val mysqlpassword = "123456" // TODO: 过滤品牌 val filterDF: Dataset[Row] = dwdreadData.filter(col("car_brand_model") =!= "未知") // TODO: 对车辆型号进行统计 val resDF: DataFrame = filterDF.groupBy("car_brand_model").count() .withColumnRenamed("count", "brand_count") resDF.show() // // TODO: 将数据写入 ClickHouse // ClickhouseUtils.writeClickHouse(resDF, "tieta_v2_ads", "ads_car_brand_cut") // println("数据成功写入 ClickHouse") // resDF.write // .format("jdbc") // .option("url", mysqlurl) // .option("dbtable", "ads_car_brand_cut") // .option("user", mysqlusername) // .option("password", mysqlpassword) // .option("driver", "com.mysql.cj.jdbc.Driver") // .mode("append") // .save() // // println("Data successfully written to mysql") } }我想在虚拟机的spark运行这个文件,我的整个项目打包的jar包在、srv/untitled3-1.0-SNAPSHOT.jar <?xml version="1.0" encoding="UTF-8"?> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>org.example</groupId> <artifactId>untitled3</artifactId> <version>1.0-SNAPSHOT</version> <properties> <maven.compiler.source>8</maven.compiler.source> <maven.compiler.target>8</maven.compiler.target> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <!-- 统一版本管理 --> <spark.version>3.1.1</spark.version> <hive.version>2.3.7</hive.version> <!-- 与 Spark 3.1.1 兼容的 Hive 版本 --> </properties> <dependencies> <!-- Spark Core --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_2.12</artifactId> <version>${spark.version}</version> </dependency> <!-- Spark SQL --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-sql_2.12</artifactId> <version>${spark.version}</version> </dependency> <!-- Spark Hive Support --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-hive_2.12</artifactId> <version>${spark.version}</version> </dependency> <!-- Spark MLlib --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-mllib_2.12</artifactId> <version>${spark.version}</version> </dependency> <!-- Hadoop Client --> <dependency> <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-client</artifactId> <version>3.2.0</version> </dependency> <!-- Hudi Spark Bundle --> <dependency> <groupId>org.apache.hudi</groupId> <artifactId>hudi-spark3.1-bundle_2.12</artifactId> <version>0.12.0</version> </dependency> <!-- ClickHouse JDBC --> <dependency> <groupId>ru.yandex.clickhouse</groupId> <artifactId>clickhouse-jdbc</artifactId> <version>0.3.2</version> </dependency> <!-- Jackson Databind --> <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-databind</artifactId> <version>2.10.0</version> </dependency> <!-- Jackson Core --> <dependency> <groupId>com.fasterxml.jackson.core</groupId> <artifactId>jackson-core</artifactId> <version>2.10.0</version> </dependency> <!-- Akka Actor --> <dependency> <groupId>com.typesafe.akka</groupId> <artifactId>akka-actor_2.12</artifactId> <version>2.6.16</version> </dependency> <!-- JSch --> <dependency> <groupId>com.jcraft</groupId> <artifactId>jsch</artifactId> <version>0.1.51</version> </dependency> <!-- Typesafe Config --> <dependency> <groupId>com.typesafe</groupId> <artifactId>config</artifactId> <version>1.4.2</version> </dependency> <!-- MySQL Connector --> <dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java</artifactId> <version>8.0.29</version> </dependency> <!-- DOM4J --> <dependency> <groupId>org.dom4j</groupId> <artifactId>dom4j</artifactId> <version>2.1.4</version> </dependency> </dependencies> <build> <plugins> <!-- Maven Shade Plugin 用于创建包含所有依赖的fat JAR --> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-shade-plugin</artifactId> <version>3.2.4</version> <executions> <execution> <phase>package</phase> <goals> <goal>shade</goal> </goals> <configuration> <transformers> <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> <!-- 如果需要设置主类,可以在这里指定 --> <!-- <mainClass>your.main.Class</mainClass> --> </transformer> </transformers> <filters> <filter> <artifact>*:*</artifact> <excludes> <exclude>META-INF/*.SF</exclude> <exclude>META-INF/*.DSA</exclude> <exclude>META-INF/*.RSA</exclude> </excludes> </filter> </filters> </configuration> </execution> </executions> </plugin> </plugins> </build> </project>这是我的pom文件,应该没问题吧
最新发布
08-12
### 项目打包与运行配置 要在虚拟机中使用 Spark 运行指定的 Scala 类 `ads.ads_car_brand_cut`,并确保 Maven 项目能够正确打包出包含所有依赖的可执行 JAR 文件,需完成以下配置和操作。 #### 1. Maven 项目结构与依赖配置 确保 `pom.xml` 文件中正确配置了 Spark 和 Scala 的依赖,并且 Scala 版本与 Spark 版本兼容。例如,若使用 Spark 3.3.0,则应使用 Scala 2.12,确保版本一致,否则会出现类加载问题[^2]。 ```xml <properties> <scala.version>2.12.15</scala.version> <spark.version>3.3.0</spark.version> </properties> <dependencies> <!-- Spark Core --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-core_${scala.version}</artifactId> <version>${spark.version}</version> </dependency> <!-- Spark SQL --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-sql_${scala.version}</artifactId> <version>${spark.version}</version> </dependency> <!-- Hive 支持 --> <dependency> <groupId>org.apache.spark</groupId> <artifactId>spark-hive_${scala.version}</artifactId> <version>${spark.version}</version> </dependency> <!-- ClickHouse JDBC --> <dependency> <groupId>ru.yandex.clickhouse</groupId> <artifactId>clickhouse-jdbc</artifactId> <version>0.3.2</version> </dependency> <!-- MySQL JDBC --> <dependency> <groupId>mysql</groupId> <artifactId>mysql-connector-java</artifactId> <version>8.0.28</version> </dependency> </dependencies> ``` #### 2. 使用 Maven Shade Plugin 打包包含依赖的 JAR 为确保 JAR 文件包含所有依赖项,使用 `maven-shade-plugin` 插件进行打包,并指定主类为 `ads.ads_car_brand_cut`,以便在运行时直接调用该类[^3]。 ```xml <build> <plugins> <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-shade-plugin</artifactId> <version>3.2.4</version> <executions> <execution> <phase>package</phase> <goals><goal>shade</goal></goals> <configuration> <transformers> <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer"> <mainClass>ads.ads_car_brand_cut</mainClass> </transformer> </transformers> </configuration> </execution> </executions> </plugin> </plugins> </build> ``` 执行以下命令进行打包: ```bash mvn clean package ``` 生成的 JAR 文件位于 `target/` 目录下,文件名类似 `spark02-1.0-SNAPSHOT.jar`,其中包含所有依赖项[^3]。 #### 3. 在虚拟机中运行 Spark 任务 将打包好的 JAR 文件上传至虚拟机,并使用 `spark-submit` 命令运行指定类: ```bash spark-submit --class ads.ads_car_brand_cut \ --master yarn \ --deploy-mode cluster \ target/spark02-1.0-SNAPSHOT.jar ``` 如需指定参数,可添加到命令末尾,例如: ```bash spark-submit --class ads.ads_car_brand_cut \ --master yarn \ --deploy-mode cluster \ target/spark02-1.0-SNAPSHOT.jar \ param1 param2 ``` #### 4. 确保工具类和配置文件正确打包 所有工具类(如 `sparkunits`、`clickhouseunits`、`MySQLunits`、`tableunits`)应放置在 `src/main/scala` 目录下,并在代码中正确引用。配置文件(如数据库连接信息)应放在 `src/main/resources` 下,确保在打包时被包含进 JAR 文件中。 #### 5. 类路径与兼容性检查 在运行过程中若出现类路径问题或类找不到的错误,应检查以下内容: - Spark 版本与依赖库是否兼容,例如 Hudi、ClickHouse JDBC 等[^4]。 - 确保虚拟机上的 Spark 环境中已安装必要的依赖 JAR 包,或通过 `--jars` 参数指定外部依赖: ```bash spark-submit --class ads.ads_car_brand_cut \ --master yarn \ --jars /path/to/hudi-spark-bundle.jar,/path/to/clickhouse-jdbc.jar \ target/spark02-1.0-SNAPSHOT.jar ``` --- ###
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值