Spark ML关于模型保存,模型加载案例

本文通过实战演示了如何使用 Apache Spark 的 MLlib 构建一个机器学习流水线,包括文本预处理、特征提取和逻辑回归模型训练。展示了从数据准备到模型训练、保存和预测的完整过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

package com.xy.data.model

import org.apache.spark.SparkConf
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.{Row, SparkSession}

object SparkMLExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().config(new SparkConf().setMaster("local[*]")).getOrCreate()
    // Prepare training documents from a list of (id, text, label) tuples.
    val training = spark.createDataFrame(Seq(
      (0L, "a b c d e spark", 1.0),
      (1L, "b d", 0.0),
      (2L, "spark f g h", 1.0),
      (3L, "hadoop mapreduce", 0.0)
    )).toDF("id", "text", "label")

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.001)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // Fit the pipeline to training documents.
    val model = pipeline.fit(training)

    // Now we can optionally save the fitted pipeline to disk
    model.write.overwrite().save("./spark-logistic-regression-model")

    // We can also save this unfit pipeline to disk
    pipeline.write.overwrite().save("./unfit-lr-model")

    // And load it back in during production
    val sameModel = PipelineModel.load("./spark-logistic-regression-model")

    // Prepare test documents, which are unlabeled (id, text) tuples.
    val test = spark.createDataFrame(Seq(
      (4L, "spark i j k"),
      (5L, "l m n"),
      (6L, "spark hadoop spark"),
      (7L, "apache hadoop")
    )).toDF("id", "text")

    // Make predictions on test documents.
    model.transform(test)
      .select("id", "text", "probability", "prediction")
      .collect()
      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
        println(s"($id, $text) --> prob=$prob, prediction=$prediction")
      }

  }

}

 

### Apache Spark 使用案例与示例代码 以下是几个常见的 Apache Spark 使用案例及其对应的示例代码: #### 1. 批量数据处理 批量数据处理是 Spark 的核心功能之一。通过 `SparkContext` 创建 RDD 并对其进行各种转换和动作操作。 ```scala val sc = new SparkContext("local", "BatchProcessingExample") val inputRDD = sc.textFile("input.txt") // 转换:将每行按空格分割成单词并统计词频 val words = inputRDD.flatMap(line => line.split("\\s+")) val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) // 动作:保存结果到文件 wordCounts.saveAsTextFile("output/wordcount") ``` 此代码展示了如何读取文本文件,将其拆分为单词,并计算每个单词的频率[^1]。 --- #### 2. 结构化数据分析(DataFrame) 使用 Spark SQL 和 DataFrame API 对结构化数据进行高效查询和分析。 ```scala import org.apache.spark.sql.SparkSession val spark = SparkSession.builder() .appName("StructuredDataAnalysis") .master("local[*]") .getOrCreate() // 加载 CSV 文件作为 DataFrame val df = spark.read.format("csv") .option("header", "true") .option("inferSchema", "true") .load("data.csv") // 数据筛选和聚合 df.filter($"age" > 30) .groupBy("gender") .agg(avg("salary"), count("*")) .show() ``` 这段代码演示了如何加载 CSV 文件、应用过滤条件以及执行分组聚合操作[^1]。 --- #### 3. 实时流数据处理(Spark Streaming) 结合 Spark Streaming 处理实时数据流,例如来自 Kafka 或其他消息队列系统的事件。 ```scala import org.apache.spark.streaming._ import org.apache.spark.streaming.kafka._ val ssc = new StreamingContext(sc, Seconds(5)) val kafkaParams = Map[String, String]("metadata.broker.list" -> "localhost:9092") val topicsSet = Set("test-topic") // 订阅 Kafka 主题 val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder]( ssc, kafkaParams, topicsSet ) // 解析 JSON 消息并提取字段 messages.map(record => record._2.parseJson.asJsObject.getFields("field")) ssc.start() ssc.awaitTermination() ``` 该示例说明了如何从 Kafka 接收实时数据流并解析其内容[^1]。 --- #### 4. 机器学习模型训练(MLlib) 利用 MLlib 构建和部署机器学习模型,例如 K-Means 聚类算法。 ```scala import org.apache.spark.ml.clustering.KMeans import org.apache.spark.sql.functions._ val dataset = spark.read.format("libsvm").load("kmeans_data.txt") // 训练 K-Means 模型 val kmeans = new KMeans().setK(2).setSeed(1L) val model = kmeans.fit(dataset) // 输出聚类中心 model.clusterCenters.foreach(println) ``` 这里展示了一个简单的 K-Means 聚类过程,适合初学者了解 MLlib 的基本用法[^4]。 --- #### 5. 累积器的应用场景 累积器是一种共享变量,通常用于计数或累加特定值的操作。 ```scala val sc = new SparkContext("local", "AccumulatorExample") val lines = sc.textFile("log.txt") val errorCounter = sc.accumulator(0) lines.foreach { line => if (line.contains("ERROR")) { errorCounter += 1 } } println(s"Total number of ERROR logs: ${errorCounter.value}") ``` 这个例子解释了如何使用累积器来统计日志文件中包含关键字 “ERROR” 的记录数量[^5]。 --- ### 性能调优建议 为了提高性能,在实际生产环境中应考虑以下几点: - 合理调整分区数目以平衡负载。 - 缓存频繁访问的数据集以减少重复计算开销。 - 根据硬件资源动态配置 Executor 内存大小和其他参数[^3]。 ---
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值