弱类型用户自定义聚合函数：通过继承 UserDefinedAggregateFunction 来实现用户自定义聚合函数（该 API 自 Spark 3.0 起已被弃用，推荐改用 Aggregator 并通过 functions.udaf 注册）。下面展示一个求平均年龄的自定义聚合函数。
1. 继承（extends）UserDefinedAggregateFunction
2. 实现其抽象方法：inputSchema、bufferSchema、dataType、deterministic、initialize、update、merge、evaluate
3. 通过 spark.udf.register 注册函数
package com.wxx.bigdata.sql03
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
/**
 * Demo entry point: loads users from a JSON file, registers [[MyAverage]] as the
 * SQL function `avgAge`, and prints the average of the `age` column.
 *
 * @param args optional; `args(0)` overrides the input JSON path
 *             (defaults to "data/test/user.json").
 */
object CustomerUDAFApp {
  def main(args: Array[String]): Unit = {
    val inputPath = args.headOption.getOrElse("data/test/user.json")
    val spark = SparkSession.builder().appName("CustomerUDAFApp").master("local[2]").getOrCreate()
    try {
      val df = spark.read.json(inputPath)
      df.createOrReplaceTempView("users")
      spark.udf.register("avgAge", MyAverage)
      spark.sql("select avgAge(age) from users").show()
    } finally {
      // Always release the session, even if reading the file or the query fails.
      spark.stop()
    }
  }
}
/**
 * Weakly-typed user-defined aggregate function that computes the average of a
 * single `LongType` argument (used here for the `age` column).
 *
 * Null inputs are skipped, like the built-in SQL `avg`. If no non-null value was
 * aggregated, the result is SQL NULL rather than NaN (0.0 / 0).
 *
 * NOTE: `UserDefinedAggregateFunction` is deprecated since Spark 3.0; the
 * recommended replacement is an `Aggregator` registered via `functions.udaf`.
 */
object MyAverage extends UserDefinedAggregateFunction {

  /** Schema of the function's input: one `LongType` argument. */
  override def inputSchema: StructType =
    new StructType().add("age", LongType)

  /** Schema of the aggregation buffer: running sum (index 0) and count of non-null inputs (index 1). */
  override def bufferSchema: StructType =
    new StructType().add("sum", LongType).add("count", LongType)

  /** Data type of the final result. */
  override def dataType: DataType = DoubleType

  /** Whether the function is deterministic (same input always yields the same output). */
  override def deterministic: Boolean = true

  /** Initializes the buffer before aggregation: sum = 0, count = 0. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L // sum
    buffer(1) = 0L // count
  }

  /**
   * Folds one input row into the buffer.
   * Null inputs are ignored: `Row.getLong` throws on a null value, which would
   * otherwise fail the whole query on a single missing `age`.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  /** Merges the partial sum/count in `buffer2` into `buffer1`. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) // sum
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) // count
  }

  /** Returns sum / count as a Double, or null when no non-null input was aggregated. */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}