Spark SQL API 大全详解

Spark SQL API 大全详解

本文全面解析 Spark SQL 的核心 API,涵盖 DataFrame、Dataset、SQL 函数及高级特性,结合代码示例展示最佳实践。基于 Spark 3.4+ 版本。

一、核心入口:SparkSession

1. 创建与配置

import org.apache.spark.sql.SparkSession

// 基础创建
val spark = SparkSession.builder()
  .appName("SparkSQL Demo")
  .master("local[*]")
  .getOrCreate()

// 高级配置
val spark = SparkSession.builder()
  .appName("Production App")
  .config("spark.sql.shuffle.partitions", "200")
  .config("spark.sql.adaptive.enabled", "true")
  .enableHiveSupport()  // 启用 Hive 支持
  .getOrCreate()

2. 重要方法

// 读取数据
val df = spark.read.format("parquet").load("path/to/data")

// 执行 SQL
spark.sql("SELECT * FROM table")

// 创建临时视图
df.createOrReplaceTempView("temp_view")

// 获取配置
val shufflePartitions = spark.conf.get("spark.sql.shuffle.partitions")

二、DataFrame API

1. 数据读取与写入

// 读取各种格式
val parquetDF = spark.read.parquet("data.parquet")
val csvDF = spark.read.option("header", true).csv("data.csv")
val jsonDF = spark.read.json("data.json")
val jdbcDF = spark.read.format("jdbc")
  .option("url", "jdbc:postgresql://localhost/db")
  .option("dbtable", "table")
  .load()

// 写入操作
df.write.mode("overwrite").parquet("output.parquet")
df.write.format("delta").save("delta_table")
df.write.partitionBy("date").bucketBy(4, "id").saveAsTable("bucketed_table")

2. 数据转换操作

// 列操作
val transformed = df
  .select($"name", $"age" + 1 as "age_plus_one") // 选择列
  .filter($"age" > 18)                          // 过滤
  .withColumn("is_adult", when($"age" >= 18, true).otherwise(false)) // 新增列
  .drop("temp_column")                          // 删除列

// 分组聚合
val aggDF = df.groupBy("department")
  .agg(
    avg("salary").alias("avg_salary"),
    count("*").alias("emp_count")
  )

// 排序
val sortedDF = df.orderBy(desc("salary"), asc("name"))

3. 关联操作

// 内连接
val innerJoin = df1.join(df2, df1("id") === df2("id"))

// 左外连接
val leftJoin = df1.join(df2, Seq("id"), "left_outer")

// 复杂连接
val complexJoin = df1.join(df2, 
  df1("dept_id") === df2("id") && df1("status") === "active",
  "left_semi"
)

4. 窗口函数

import org.apache.spark.sql.expressions.Window

val windowSpec = Window.partitionBy("department").orderBy("salary")

val rankedDF = df.withColumn("rank", rank().over(windowSpec))
  .withColumn("dense_rank", dense_rank().over(windowSpec))
  .withColumn("row_number", row_number().over(windowSpec))

三、Dataset API(类型安全)

1. 基本操作

case class Person(name: String, age: Int, salary: Double)

// 创建 Dataset
val personDS = spark.read.json("people.json").as[Person]

// 类型安全操作
val adultsDS = personDS.filter(p => p.age >= 18)
val avgSalaryDS = personDS.map(p => (p.department, p.salary))
  .groupByKey(_._1)
  .agg(avg(_._2).alias("avg_salary"))

2. 强类型函数

import org.apache.spark.sql.functions.udf

// 定义 UDF
val toUpper = udf((s: String) => s.toUpperCase)

// 使用 UDF
df.withColumn("upper_name", toUpper($"name"))

// 类型安全的 UDF
val salaryBonus = udf((salary: Double, bonus: Double) => salary + bonus)

四、SQL 函数大全

1. 字符串函数

import org.apache.spark.sql.functions._

df.select(
  concat_ws(" ", $"first_name", $"last_name").as("full_name"),
  substring($"email", 1, 5).as("email_prefix"),
  regexp_extract($"url", "(https?://[^/]+)", 1).as("domain"),
  length($"description").as("desc_length")
)

2. 日期时间函数

df.select(
  current_date().as("today"),
  date_format($"order_date", "yyyy-MM").as("order_month"),
  datediff(current_date(), $"birth_date").as("age_in_days"),
  add_months($"start_date", 12).as("anniversary"),
  from_unixtime(unix_timestamp()).as("current_time")
)

3. 数学函数

df.select(
  round($"salary" * 1.1, 2).as("increased_salary"),
  sqrt($"area").as("side_length"),
  log("e", $"value").as("ln_value"),
  rand(seed=123).as("random_value")
)

4. 聚合函数

df.groupBy("department")
  .agg(
    countDistinct("employee_id").as("distinct_employees"),
    percentile_approx("salary", lit(0.5), lit(100)).as("median_salary"),
    collect_list("project").as("all_projects"),
    sum(when($"status" === "completed", 1).otherwise(0)).as("completed_count")
  )

5. 高级函数

// JSON 处理
df.select(
  get_json_object($"json_col", "$.address.city").as("city"),
  to_json(struct($"name", $"age")).as("json_string")
)

// 数组处理
df.select(
  array_contains($"tags", "urgent").as("is_urgent"),
  explode($"items").as("single_item"),
  size($"phone_numbers").as("phone_count")
)

// Map 处理
df.select(
  map_keys($"properties").as("property_names"),
  $"properties".getItem("color").as("color")
)

五、高级 API 功能

1. 用户自定义函数(UDF)

// 标量 UDF
spark.udf.register("geo_distance", (lat1: Double, lon1: Double, lat2: Double, lon2: Double) => {
  // Haversine 距离计算实现
  val R = 6371 // 地球半径(km)
  val dLat = Math.toRadians(lat2 - lat1)
  val dLon = Math.toRadians(lon2 - lon1)
  val a = Math.sin(dLat/2) * Math.sin(dLat/2) +
          Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) * 
          Math.sin(dLon/2) * Math.sin(dLon/2)
  val c = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1-a))
  R * c
})

// 聚合 UDF
class Average extends UserDefinedAggregateFunction {
  def inputSchema: StructType = new StructType().add("value", DoubleType)
  def bufferSchema: StructType = new StructType().add("sum", DoubleType).add("count", LongType)
  def dataType: DataType = DoubleType
  def deterministic: Boolean = true
  
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0L
  }
  
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getDouble(0) + input.getDouble(0)
    buffer(1) = buffer.getLong(1) + 1
  }
  
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  
  def evaluate(buffer: Row): Double = buffer.getDouble(0) / buffer.getLong(1)
}

val customAvg = new Average
df.groupBy("category").agg(customAvg($"price").as("avg_price"))

2. 向量化 Pandas UDF

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

@pandas_udf(DoubleType())
def predict_udf(model, features: pd.Series) -> pd.Series:
    # 使用预训练模型进行批量预测
    return pd.Series(model.predict(features.to_numpy()))

# 应用 UDF
df.withColumn("prediction", predict_udf(lit(model), struct("feature1", "feature2")))

3. 高阶函数

// 转换数组
df.select(
  transform($"items", item => item * 0.9).as("discounted_items")
)

// 过滤数组
df.select(
  filter($"tags", tag => tag.startsWith("imp_")).as("important_tags")
)

// 聚合数组
df.select(
  aggregate($"values", lit(0.0), (acc, x) => acc + x).as("total_sum")
)

六、结构化流处理 API

1. 流式读取与处理

val streamDF = spark.readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "broker:9092")
  .option("subscribe", "topic")
  .load()

val processedStream = streamDF
  .select(from_json($"value".cast("string"), schema).as("data"))
  .select("data.*")
  .withWatermark("event_time", "10 minutes")
  .groupBy(window($"event_time", "5 minutes"), $"device_id")
  .agg(count("*").as("event_count"))

2. 流式输出

processedStream.writeStream
  .outputMode("update")
  .format("delta")
  .option("checkpointLocation", "/checkpoint/dir")
  .trigger(Trigger.ProcessingTime("1 minute"))
  .start("path/to/delta_table")

七、元数据操作 API

1. Catalog 操作

val catalog = spark.catalog

// 数据库操作
catalog.createDatabase("new_db", Map("comment" -> "Test database"))
catalog.setCurrentDatabase("new_db")

// 表操作
catalog.createTable("employees", "parquet", Map("path" -> "/data/employees"))
catalog.listTables().show()
catalog.refreshTable("sales")  // 刷新元数据

// 函数管理
catalog.createFunction("distance", "com.geo.udf.DistanceCalculator")
catalog.listFunctions().filter(_.name.contains("dist")).show()

八、性能优化 API

1. 缓存与持久化

// 缓存策略
df.persist(StorageLevel.MEMORY_AND_DISK_SER)
df.cache()  // 默认 MEMORY_AND_DISK

// 检查存储状态
println(df.storageLevel)

// 释放缓存
df.unpersist()

2. 优化提示

// 重分区提示
val repartitioned = df.hint("REPARTITION", 100)

// Join 策略提示
val optimizedJoin = df1.join(df2.hint("BROADCAST"), "id")

// 倾斜处理提示
val skewJoin = df1.join(df2.hint("SKEW", "id"), "id")

九、Spark 3.x 新特性 API

1. Adaptive Query Execution (AQE)

// 启用 AQE
spark.conf.set("spark.sql.adaptive.enabled", true)

// 合并小分区
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", true)
spark.conf.set("spark.sql.adaptive.coalescePartitions.minPartitionSize", "16MB")

// 优化 Join
spark.conf.set("spark.sql.adaptive.optimizeSkewsInRebalancePartitions.enabled", true)

2. Delta Lake 集成

// 创建 Delta 表
DeltaTable.create(spark)
  .tableName("events")
  .addColumn("event_id", "BIGINT")
  .addColumn("event_time", "TIMESTAMP")
  .partitionedBy("date")
  .execute()

// 时间旅行
spark.read.format("delta")
  .option("versionAsOf", 12)
  .load("/delta/events")

3. 增强的 Python API

# 类型提示支持
from pyspark.sql.types import StructType, StructField, StringType

schema = StructType([
  StructField("name", StringType(), True),
  StructField("city", StringType(), True)
])

# 增强的错误信息
spark.sql("SELECT * FROM non_existent_table")  # 显示详细表不存在信息

最佳实践总结

  1. 优先使用 DataFrame API:利用 Catalyst 优化器优势

  2. 类型安全场景用 Dataset:复杂业务逻辑使用强类型 API

  3. 函数选择原则

    • 内置函数 > UDF > Pandas UDF
    • 列式操作 > 行式操作
  4. 流处理注意事项

    • 设置合理的水位线
    • 启用检查点保证容错
  5. 性能优化关键点

    // 启用 AQE
    spark.conf.set("spark.sql.adaptive.enabled", true)
    
    // 使用列式存储
    df.write.format("parquet").save(...)
    
    // 合理分区
    df.repartition(200, $"date")
    
  6. API 版本兼容

    // 使用官方稳定 API
    import org.apache.spark.sql.functions.{col, lit}
    // 避免内部私有 API
    // import org.apache.spark.sql.catalyst.expressions._ ❌
    

Spark SQL API 持续演进,建议定期查阅 Spark 官方文档 获取最新特性。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值