Spark SQL API 大全详解
本文全面解析 Spark SQL 的核心 API,涵盖 DataFrame、Dataset、SQL 函数及高级特性,结合代码示例展示最佳实践。基于 Spark 3.4+ 版本。
一、核心入口:SparkSession
1. 创建与配置
import org.apache.spark.sql.SparkSession
// 基础创建
val spark = SparkSession.builder()
.appName("SparkSQL Demo")
.master("local[*]")
.getOrCreate()
// 高级配置
val spark = SparkSession.builder()
.appName("Production App")
.config("spark.sql.shuffle.partitions", "200")
.config("spark.sql.adaptive.enabled", "true")
.enableHiveSupport() // 启用 Hive 支持
.getOrCreate()
2. 重要方法
// 读取数据
val df = spark.read.format("parquet").load("path/to/data")
// 执行 SQL
spark.sql("SELECT * FROM table")
// 创建临时视图
df.createOrReplaceTempView("temp_view")
// 获取配置
val shufflePartitions = spark.conf.get("spark.sql.shuffle.partitions")
二、DataFrame API
1. 数据读取与写入
// 读取各种格式
val parquetDF = spark.read.parquet("data.parquet")
val csvDF = spark.read.option("header", true).csv("data.csv")
val jsonDF = spark.read.json("data.json")
val jdbcDF = spark.read.format("jdbc")
.option("url", "jdbc:postgresql://localhost/db")
.option("dbtable", "table")
.load()
// 写入操作
df.write.mode("overwrite").parquet("output.parquet")
df.write.format("delta").save("delta_table")
df.write.partitionBy("date").bucketBy(4, "id").saveAsTable("bucketed_table")
2. 数据转换操作
// 列操作
val transformed = df
.select($"name", $"age" + 1 as "age_plus_one") // 选择列
.filter($"age" > 18) // 过滤
.withColumn("is_adult", when($"age" >= 18, true).otherwise(false)) // 新增列
.drop("temp_column") // 删除列
// 分组聚合
val aggDF = df.groupBy("department")
.agg(
avg("salary").alias("avg_salary"),
count("*").alias("emp_count")
)
// 排序
val sortedDF = df.orderBy(desc("salary"), asc("name"))
3. 关联操作
// 内连接
val innerJoin = df1.join(df2, df1("id") === df2("id"))
// 左外连接
val leftJoin = df1.join(df2, Seq("id"), "left_outer")
// 复杂连接
val complexJoin = df1.join(df2,
df1("dept_id") === df2("id") && df1("status") === "active",
"left_semi"
)
4. 窗口函数
import org.apache.spark.sql.expressions.Window
val windowSpec = Window.partitionBy("department").orderBy("salary")
val rankedDF = df.withColumn("rank", rank().over(windowSpec))
.withColumn("dense_rank", dense_rank().over(windowSpec))
.withColumn("row_number", row_number().over(windowSpec))
三、Dataset API(类型安全)
1. 基本操作
case class Person(name: String, age: Int, salary: Double)
// 创建 Dataset
val personDS = spark.read.json("people.json").as[Person]
// 类型安全操作
val adultsDS = personDS.filter(p => p.age >= 18)
val avgSalaryDS = personDS.map(p => (p.department, p.salary))
.groupByKey(_._1)
.agg(avg(_._2).alias("avg_salary"))
2. 强类型函数
import org.apache.spark.sql.functions.udf
// 定义 UDF
val toUpper = udf((s: String) => s.toUpperCase)
// 使用 UDF
df.withColumn("upper_name", toUpper($"name"))
// 类型安全的 UDF
val salaryBonus = udf((salary: Double, bonus: Double) => salary + bonus)
四、SQL 函数大全
1. 字符串函数
import org.apache.spark.sql.functions._
df.select(
concat_ws(" ", $"first_name", $"last_name").as("full_name"),
substring($"email", 1, 5).as("email_prefix"),
regexp_extract($"url", "(https?://[^/]+)", 1).as("domain"),
length($"description").as("desc_length")
)
2. 日期时间函数
df.select(
current_date().as("today"),
date_format($"order_date", "yyyy-MM").as("order_month"),
datediff(current_date(), $"birth_date").as("age_in_days"),
add_months($"start_date", 12).as("anniversary"),
from_unixtime(unix_timestamp()).as("current_time")
)
3. 数学函数
df.select(
round($"salary" * 1.1, 2).as("increased_salary"),
sqrt($"area").as("side_length"),
log("e", $"value").as("ln_value"),
rand(seed=123).as("random_value")
)
4. 聚合函数
df.groupBy("department")
.agg(
countDistinct("employee_id").as("distinct_employees"),
percentile_approx("salary", lit(0.5), lit(100)).as("median_salary"),
collect_list("project").as("all_projects"),
sum(when($"status" === "completed", 1).otherwise(0)).as("completed_count")
)
5. 高级函数
// JSON 处理
df.select(
get_json_object($"json_col", "$.address.city").as("city"),
to_json(struct($"name", $"age")).as("json_string")
)
// 数组处理
df.select(
array_contains($"tags", "urgent").as("is_urgent"),
explode($"items").as("single_item"),
size($"phone_numbers").as("phone_count")
)
// Map 处理
df.select(
map_keys($"properties").as("property_names"),
$"properties".getItem("color").as("color")
)
五、高级 API 功能
1. 用户自定义函数(UDF)
// 标量 UDF
spark.udf.register("geo_distance", (lat1: Double, lon1: Double, lat2: Double, lon2: Double) => {
// Haversine 距离计算实现
val R = 6371 // 地球半径(km)
val dLat = Math.toRadians(lat2 - lat1)
val dLon = Math.toRadians(lon2 - lon1)
val a = Math.sin(dLat/2) * Math.sin(dLat/2) +
Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2)) *
Math.sin(dLon/2) * Math.sin(dLon/2)
val c = 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1-a))
R * c
})
// 聚合 UDF
class Average extends UserDefinedAggregateFunction {
def inputSchema: StructType = new StructType().add("value", DoubleType)
def bufferSchema: StructType = new StructType().add("sum", DoubleType).add("count", LongType)
def dataType: DataType = DoubleType
def deterministic: Boolean = true
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0.0
buffer(1) = 0L
}
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getDouble(0) + input.getDouble(0)
buffer(1) = buffer.getLong(1) + 1
}
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
def evaluate(buffer: Row): Double = buffer.getDouble(0) / buffer.getLong(1)
}
val customAvg = new Average
df.groupBy("category").agg(customAvg($"price").as("avg_price"))
2. 向量化 Pandas UDF
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
@pandas_udf(DoubleType())
def predict_udf(model, features: pd.Series) -> pd.Series:
# 使用预训练模型进行批量预测
return pd.Series(model.predict(features.to_numpy()))
# 应用 UDF
df.withColumn("prediction", predict_udf(lit(model), struct("feature1", "feature2")))
3. 高阶函数
// 转换数组
df.select(
transform($"items", item => item * 0.9).as("discounted_items")
)
// 过滤数组
df.select(
filter($"tags", tag => tag.startsWith("imp_")).as("important_tags")
)
// 聚合数组
df.select(
aggregate($"values", lit(0.0), (acc, x) => acc + x).as("total_sum")
)
六、结构化流处理 API
1. 流式读取与处理
val streamDF = spark.readStream
.format("kafka")
.option("kafka.bootstrap.servers", "broker:9092")
.option("subscribe", "topic")
.load()
val processedStream = streamDF
.select(from_json($"value".cast("string"), schema).as("data"))
.select("data.*")
.withWatermark("event_time", "10 minutes")
.groupBy(window($"event_time", "5 minutes"), $"device_id")
.agg(count("*").as("event_count"))
2. 流式输出
processedStream.writeStream
.outputMode("update")
.format("delta")
.option("checkpointLocation", "/checkpoint/dir")
.trigger(Trigger.ProcessingTime("1 minute"))
.start("path/to/delta_table")
七、元数据操作 API
1. Catalog 操作
val catalog = spark.catalog
// 数据库操作
catalog.createDatabase("new_db", Map("comment" -> "Test database"))
catalog.setCurrentDatabase("new_db")
// 表操作
catalog.createTable("employees", "parquet", Map("path" -> "/data/employees"))
catalog.listTables().show()
catalog.refreshTable("sales") // 刷新元数据
// 函数管理
catalog.createFunction("distance", "com.geo.udf.DistanceCalculator")
catalog.listFunctions().filter(_.name.contains("dist")).show()
八、性能优化 API
1. 缓存与持久化
// 缓存策略
df.persist(StorageLevel.MEMORY_AND_DISK_SER)
df.cache() // 默认 MEMORY_AND_DISK
// 检查存储状态
println(df.storageLevel)
// 释放缓存
df.unpersist()
2. 优化提示
// 重分区提示
val repartitioned = df.hint("REPARTITION", 100)
// Join 策略提示
val optimizedJoin = df1.join(df2.hint("BROADCAST"), "id")
// 倾斜处理提示
val skewJoin = df1.join(df2.hint("SKEW", "id"), "id")
九、Spark 3.x 新特性 API
1. Adaptive Query Execution (AQE)
// 启用 AQE
spark.conf.set("spark.sql.adaptive.enabled", true)
// 合并小分区
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", true)
spark.conf.set("spark.sql.adaptive.coalescePartitions.minPartitionSize", "16MB")
// 优化 Join
spark.conf.set("spark.sql.adaptive.optimizeSkewsInRebalancePartitions.enabled", true)
2. Delta Lake 集成
// 创建 Delta 表
DeltaTable.create(spark)
.tableName("events")
.addColumn("event_id", "BIGINT")
.addColumn("event_time", "TIMESTAMP")
.partitionedBy("date")
.execute()
// 时间旅行
spark.read.format("delta")
.option("versionAsOf", 12)
.load("/delta/events")
3. 增强的 Python API
# 类型提示支持
from pyspark.sql.types import StructType, StructField, StringType
schema = StructType([
StructField("name", StringType(), True),
StructField("city", StringType(), True)
])
# 增强的错误信息
spark.sql("SELECT * FROM non_existent_table") # 显示详细表不存在信息
最佳实践总结
-
优先使用 DataFrame API:利用 Catalyst 优化器优势
-
类型安全场景用 Dataset:复杂业务逻辑使用强类型 API
-
函数选择原则:
- 内置函数 > UDF > Pandas UDF
- 列式操作 > 行式操作
-
流处理注意事项:
- 设置合理的水位线
- 启用检查点保证容错
-
性能优化关键点:
// 启用 AQE spark.conf.set("spark.sql.adaptive.enabled", true) // 使用列式存储 df.write.format("parquet").save(...) // 合理分区 df.repartition(200, $"date") -
API 版本兼容:
// 使用官方稳定 API import org.apache.spark.sql.functions.{col, lit} // 避免内部私有 API // import org.apache.spark.sql.catalyst.expressions._ ❌
Spark SQL API 持续演进,建议定期查阅 Spark 官方文档 获取最新特性。
1041

被折叠的 条评论
为什么被折叠?



