Spark 自定义聚合函数(UDAF)UserDefinedAggregateFunction 原理用法示例源码分析
原理
UserDefinedAggregateFunction 是 Spark SQL 中用于实现用户自定义聚合函数(UDAF)的抽象类。通过继承该类并实现其中的方法,可以创建自定义的聚合函数,并在 Spark SQL 中使用。
UserDefinedAggregateFunction 的原理是基于 Spark SQL 的聚合操作流程。当一个 UDAF 被应用到 DataFrame 上时,Spark SQL 会将 UDAF 转化为一个 AggregateExpression 对象,其中包含了对应的 ScalaUDAF 实例和聚合操作类型。然后,Spark SQL 会对数据进行分组、聚合等操作,并调用 UDAF 中的方法来执行具体的聚合逻辑。
在具体实现中,UserDefinedAggregateFunction 提供了一系列方法,如 inputSchema、bufferSchema、dataType 等,用于定义输入参数的数据类型、缓冲区中值的数据类型以及返回值的数据类型。同时,它还提供了 initialize、update、merge 和 evaluate 方法,用于初始化聚合缓冲区、更新缓冲区、合并缓冲区以及计算最终结果。此外,UserDefinedAggregateFunction 还提供了 apply 和 distinct 方法,用于创建 Column 对象,方便在 DataFrame 中使用自定义聚合函数。
总的来说,UserDefinedAggregateFunction 通过定义一系列方法,使得用户可以灵活地实现自定义的聚合逻辑,并将其应用到 Spark SQL 的聚合操作中。通过这种方式,用户可以扩展 Spark SQL 中的聚合能力,满足特定的业务需求。
用法
| 方法名 | 描述 |
|---|---|
inputSchema | 返回聚合函数的输入参数的数据类型的 StructType。 |
bufferSchema | 返回聚合缓冲区中值的数据类型的 StructType。 |
dataType | 返回聚合函数的返回值的数据类型。 |
deterministic | 返回布尔值,指示此函数是否是确定性的。 |
initialize(buffer) | 初始化给定的聚合缓冲区。 |
update(buffer, input) | 使用新的输入数据更新聚合缓冲区。 |
merge(buffer1, buffer2) | 合并两个聚合缓冲区。 |
evaluate(buffer) | 根据给定的聚合缓冲区计算最终结果。 |
apply(exprs) | 使用给定的 Column 参数创建一个 Column 对象来调用 UDAF。 |
distinct(exprs) | 使用给定的不同值的 Column 参数创建一个 Column 对象来调用 UDAF。 |
update(i, value) | 更新可变聚合缓冲区的第 i 个值。 |
示例
package org.example.spark
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
object AverageVecDemo {
// 创建自定义聚合函数
class MyAverage extends UserDefinedAggregateFunction {
// 输入参数的数据类型
def inputSchema: StructType = new StructType().add("value", DoubleType)
// 聚合缓冲区中值的数据类型
def bufferSchema: StructType = new StructType()
.add("sum", DoubleType)
.add("count", LongType)
// 返回值的数据类型
def dataType: DataType = DoubleType
// 是否是确定性的
def deterministic: Boolean = true
// 初始化聚合缓冲区
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0.0 // sum
buffer(1) = 0L // count
}
// 更新聚合缓冲区
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0)) {
val value = input.getDouble(0)
buffer(0) = buffer.getDouble(0) + value
buffer(1) = buffer.getLong(1) + 1
}
}
// 合并两个聚合缓冲区
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
// 计算最终结果
def evaluate(buffer: Row): Any = {
buffer.getDouble(0) / buffer.getLong(1)
}
}
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("UDAFDemo")
.master("local[*]")
.getOrCreate()
import spark.implicits._
// 创建一个 DataFrame
val data = Seq(1.0, 2.0, 3.0, 4.0, 5.0).toDF("value")
// 注册自定义聚合函数
spark.udf.register("myAverage", new MyAverage)
// 使用自定义聚合函数进行聚合操作
val result = data.selectExpr("myAverage(value) as average")
result.show()
spark.stop()
}
}
//+-------+
//|average|
//+-------+
//| 3.0|
//+-------+
这个示例中,我们创建了一个自定义聚合函数 MyAverage,用于计算输入数据列的平均值。然后,我们将该函数注册到 Spark 的 UDF(用户定义函数)中,并在 DataFrame 中使用 selectExpr 方法调用它进行聚合操作。最后,我们展示了聚合结果。
源码
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.types._
/**
* 实现用户自定义聚合函数(UDAF)的基类。
*
* @since 1.5.0
*/
@InterfaceStability.Stable
abstract class UserDefinedAggregateFunction extends Serializable {
/**
* `StructType` 表示此聚合函数的输入参数的数据类型。
* 例如,如果一个[[UserDefinedAggregateFunction]]期望两个输入参数,
* 分别是`DoubleType`和`LongType`类型,返回的`StructType`将如下所示:
*
* ```
* new StructType()
* .add("doubleInput", DoubleType)
* .add("longInput", LongType)
* ```
*
* 此`StructType`的字段名称仅用于标识对应的输入参数。用户可以选择名称以标识输入参数。
*
* @since 1.5.0
*/
def inputSchema: StructType
/**
* `StructType` 表示聚合缓冲区中值的数据类型。
* 例如,如果一个[[UserDefinedAggregateFunction]]的缓冲区有两个值
* (即两个中间值),分别是`DoubleType`和`LongType`类型,
* 返回的`StructType`将如下所示:
*
* ```
* new StructType()
* .add("doubleInput", DoubleType)
* .add("longInput", LongType)
* ```
*
* 此`StructType`的字段名称仅用于标识对应的缓冲区值。用户可以选择名称以标识输入参数。
*
* @since 1.5.0
*/
def bufferSchema: StructType
/**
* [[UserDefinedAggregateFunction]] 返回值的 `DataType`。
*
* @since 1.5.0
*/
def dataType: DataType
/**
* 如果此函数是确定性的,则返回true,即给定相同的输入,总是返回相同的输出。
*
* @since 1.5.0
*/
def deterministic: Boolean
/**
* 初始化给定的聚合缓冲区,即聚合缓冲区的初始值。
*
* 即应用于两个初始缓冲区的合并函数只应返回初始缓冲区本身,即
* `merge(initialBuffer, initialBuffer)` 应等于 `initialBuffer`。
*
* @since 1.5.0
*/
def initialize(buffer: MutableAggregationBuffer): Unit
/**
* 使用来自`input`的新输入数据更新给定的聚合缓冲区`buffer`。
*
* 每行输入调用一次此方法。
*
* @since 1.5.0
*/
def update(buffer: MutableAggregationBuffer, input: Row): Unit
/**
* 合并两个聚合缓冲区,并将更新后的缓冲区值存储回`buffer1`。
*
* 当我们合并两个部分聚合的数据时,会调用此方法。
*
* @since 1.5.0
*/
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit
/**
* 根据给定的聚合缓冲区计算此[[UserDefinedAggregateFunction]]的最终结果。
*
* @since 1.5.0
*/
def evaluate(buffer: Row): Any
/**
* 使用给定的`Column`s作为输入参数创建此UDAF的`Column`。
*
* @since 1.5.0
*/
@scala.annotation.varargs
def apply(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
ScalaUDAF(exprs.map(_.expr), this),
Complete,
isDistinct = false)
Column(aggregateExpression)
}
/**
* 使用给定的`Column`s的不同值作为输入参数创建此UDAF的`Column`。
*
* @since 1.5.0
*/
@scala.annotation.varargs
def distinct(exprs: Column*): Column = {
val aggregateExpression =
AggregateExpression(
ScalaUDAF(exprs.map(_.expr), this),
Complete,
isDistinct = true)
Column(aggregateExpression)
}
}
/**
* 表示可变聚合缓冲区的`Row`。
*
* 不建议在Spark之外扩展它。
*
* @since 1.5.0
*/
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {
/** 更新此缓冲区的第i个值。 */
def update(i: Int, value: Any): Unit
}
gregateExpression)
}
}
/**
* 表示可变聚合缓冲区的`Row`。
*
* 不建议在Spark之外扩展它。
*
* @since 1.5.0
*/
@InterfaceStability.Stable
abstract class MutableAggregationBuffer extends Row {
/** 更新此缓冲区的第i个值。 */
def update(i: Int, value: Any): Unit
}
参考链接
https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html
文章详细介绍了SparkSQL中UserDefinedAggregateFunction(UDAF)的原理,包括其在聚合操作流程中的角色,如何通过继承抽象类并实现相关方法来创建自定义聚合函数,以及如何定义输入、缓冲区和返回值的数据类型。还提供了用法示例和源码分析。
565

被折叠的 条评论
为什么被折叠?



