SparkSql 中用户自定义聚合函数---弱类型

        强类型的Dataset和弱类型的DataFrame都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数。

       弱类型用户自定义聚合函数:通过继承UserDefinedAggregateFunction来实现用户自定义聚合函数。

     弱类型 Demo 自定义求用户平均年龄的聚合函数。

package com.bigdata.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object SparkSQL05_UDAF {
  def main(args: Array[String]): Unit = {

    // c创建conf文件
    val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("sql")

    //  创建SparkSession

    val spark: SparkSession = SparkSession.builder().config(sparkConf).getOrCreate()

    // 导入隐式转换

    import spark.implicits._

    // 自定义聚合函数
 //  创建聚合函数对象
    val udaf = new MyageAvgFunction

    // 注册聚合函数
    spark.udf.register("Avgage",udaf)
    // 创建RDD
    val rdd = spark.sparkContext.makeRDD(List((1,"zhangsan",20), (2,"lisi", 30), (3,"wangwu",40)))
    // 转换为DF

    val df: DataFrame = rdd.toDF("id","name","age")

    // 创建视图

    df.createOrReplaceTempView("user")

    //  使用聚合函数

    spark.sql("select Avgage(age) avgage from user").show()


   // 关闭资源
   spark.stop()

  }

}


// 声明用户自定义的聚合函数
// 1) 继承UserDefinedAggregateFunction
// 2) 实现方法

class MyageAvgFunction extends UserDefinedAggregateFunction{

  // 函数输入时的数据结构
  override def inputSchema: StructType = {
    new StructType().add("age",LongType)
  }

  // 计算时的数据结构
  override def bufferSchema: StructType = {

    new StructType().add("sum",LongType).add("count",LongType)

  }

  // 函数返回时的数据类型
  override def dataType: DataType = DoubleType

  // 函数是否稳定 (给相同的数据返回结果应该是一样的)
  override def deterministic: Boolean = true

// 计算之之前的缓冲区的数据初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L //sum
    buffer(1) = 0L   // count
  }


// 根据查询结果更新缓冲区
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
  buffer(0) = buffer.getLong(0)+input.getLong(0)
    buffer(1) = buffer.getLong(1)+1


  }

  // 多个节点缓冲区合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {

    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)

  }


  // 计算最终的结果
  override def evaluate(buffer: Row): Any = {

    buffer.getLong(0)/buffer.getLong(1).toDouble

  }
}

     计算结果:

                           

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值