SparkSql自定义强类型、弱类型聚合函数

本文详细介绍如何在Spark SQL中自定义弱类型和强类型的聚合函数(UDAF),包括求和(sum)和求平均值(avg)的具体实现,通过案例演示如何在DataFrame和Dataset上应用自定义的聚合函数。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

自定义弱类型

package com.chen.sparksql.func

import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("UDAFDemo").master("local[2]").getOrCreate()

    val df: DataFrame = spark.read.json("d:/user.json")

    df.createOrReplaceTempView("user")


    // 注册聚合函数
    spark.udf.register("mySum",new MySum)
    spark.udf.register("myAvg",new MyAvg)

    spark.sql("select mySum(salary),myAvg(salary) from user").show

    spark.close()
  }

}

/**
 * 自定义聚合函数 实现sum求和功能
 */
class MySum extends UserDefinedAggregateFunction {
  // 输入数据类型
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)

  // 缓冲区类型
  override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil)

  // 最终聚合结果的数据类型
  override def dataType: DataType = DoubleType

  // 相同的输入是否返回相同的输出
  override def deterministic: Boolean = true

  // 对缓冲区初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // 在缓冲集合中初始化
    buffer(0) = 0D // 等同     buffer.update(0,0D)
  }

  // 分区内聚合
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // input 是指使用聚合函数时,缓存过来的参数封装到了Row中
    if (!input.isNullAt(0)) { // 考虑到每行数据传入时 可能有nill的情况 会报错
      //      val v: Double = input.getDouble(0)
      val v: Double = input.getAs[Double](0)
      buffer(0) = buffer.getDouble(0) + v
    }

  }

  // 分区间聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // buffer1 和buffer2聚合在一起 之后写回buffer1
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }

  // 最终返回值
  override def evaluate(buffer: Row): Any = buffer(0)
}


/**
 * 自定义聚合函数 实现avg 求平均值功能
 */
class MyAvg extends UserDefinedAggregateFunction {
  // 输入数据类型
  override def inputSchema: StructType = StructType(StructField("ele", DoubleType) :: Nil)

  // 缓冲区类型
  override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: StructField("count", LongType)::Nil)

  // 最终聚合结果的数据类型
  override def dataType: DataType = DoubleType

  // 相同的输入是否返回相同的输出
  override def deterministic: Boolean = true

  // 对缓冲区初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // 在缓冲集合中初始化
    buffer(0) = 0D // 等同     buffer.update(0,0D)
    buffer(1) = 0L // count
  }

  // 分区内聚合
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // input 是指使用聚合函数时,缓存过来的参数封装到了Row中
    if (!input.isNullAt(0)) { // 考虑到每行数据传入时 可能有nill的情况 会报错
      //      val v: Double = input.getDouble(0)
      val v: Double = input.getAs[Double](0)
      buffer(0) = buffer.getDouble(0) + v
      buffer(1) = buffer.getLong(1) + 1L
    }

  }

  // 分区间聚合
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // buffer1 和buffer2聚合在一起 之后写回buffer1
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // 最终返回值
  override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getLong(1)
}

自定义强类型

package com.chen.sparksql.func

import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}


case class Dog(name: String, age: Int)

case class AgeAvg(sum: Int, count: Int) {
  def avg = sum.toDouble / count
}

object UDAFDemo2{
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("UDAFDemo").master("local[2]").getOrCreate()

    import spark.implicits._
    val ds: Dataset[Dog] = List(Dog("dahuang", 8), Dog("xiaohuang", 4), Dog("zhonghuang", 6)).toDS()

    // 强类型使用方式 需要转成TypedColumn类型
    val avg: TypedColumn[Dog, Double] = new MyAvg2().toColumn.name("abg")
    ds.select(avg).show
    
    spark.close()
  }

}


/**
 * 自定义强类型 聚合函数 实现avg 求平均值功能
 * 可在ds中使用
 */
class MyAvg2 extends Aggregator[Dog, AgeAvg, Double] {
  // 对缓冲区进行初始化
  override def zero: AgeAvg = AgeAvg(0, 0)

  // 分区内聚合
  override def reduce(b: AgeAvg, a: Dog): AgeAvg = a match {
    case Dog(name, age) => AgeAvg(b.sum + age, b.count + 1)
    // 如果为null 直接返回
    case _ => b
  }

  // 分区间聚合
  override def merge(b1: AgeAvg, b2: AgeAvg): AgeAvg = AgeAvg(b1.sum + b2.sum, b1.count + b2.count)

  // 返回最终的值
  override def finish(reduction: AgeAvg): Double = reduction.avg

  // 对缓冲区进行编码
  override def bufferEncoder: Encoder[AgeAvg] = Encoders.product // 若为样例类,直接返回product编码器

  // 对返回值进行编码
  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
Apache Spark 是专为大规模数据处理而设计的快速通用的计算引擎。Spark是UC Berkeley AMP lab (加州大学伯克利分校的AMP实验室)所开源的类Hadoop MapReduce的通用并行框架,Spark,拥有Hadoop MapReduce所具有的优点;但不同于MapReduce的是——Job中间输出结果可以保存在内存中,从而不再需要读写HDFS,因此Spark能更好地适用于数据挖掘与机器学习等需要迭代的MapReduce的算法。Spark 是一种与 Hadoop 相似的开源集群计算环境,但是两者之间还存在一些不同之处,这些有用的不同之处使 Spark 在某些工作负载方面表现得更加优越,换句话说,Spark 启用了内存分布数据集,除了能够提供交互式查询外,它还可以优化迭代工作负载。Spark 是在 Scala 语言中实现的,它将 Scala 用作其应用程序框架。与 Hadoop 不同,Spark 和 Scala 能够紧密集成,其中的 Scala 可以像操作本地集合对象一样轻松地操作分布式数据集。尽管创建 Spark 是为了支持分布式数据集上的迭代作业,但是实际上它是对 Hadoop 的补充,可以在 Hadoop 文件系统中并行运行。通过名为 Mesos 的第三方集群框架可以支持此行为。Spark 由加州大学伯克利分校 AMP 实验室 (Algorithms, Machines, and People Lab) 开发,可用来构建大型的、低延迟的数据分析应用程序。本部分内容全面涵盖了Spark生态系统的概述及其编程模型,深入内核的研究,Spark on Yarn,Spark RDD、Spark Streaming流式计算原理与实践,Spark SQL,Spark的多语言编程以及SparkR的原理和运行。本套Spark教程不仅面向项目开发人员,甚至对于研究Spark的在校学员,都是非常值得学习的。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值