Custom UDFs and UDAFs in Spark SQL and Spark DataFrames

This post takes a close look at user-defined functions (UDF) and user-defined aggregate functions (UDAF) in Spark. Using a sample dataset and complete code examples, it shows how to register and apply UDFs and UDAFs in both the Spark DataFrame API and Spark SQL to compute statistics such as students' average scores.


UDF, UDAF, UDTF

  • UDF: user-defined function (User Defined Function). One input row, one output row.
  • UDAF: user-defined aggregate function (User Defined Aggregate Function). Multiple input rows, one output row.
  • UDTF: user-defined table-generating function (User Defined Table Generating Function). One input row, multiple output rows, e.g. the explode and json_tuple functions in Hive/Spark.

UDTFs are rarely used, so this post only covers UDFs and UDAFs.
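Although UDTFs are not covered further here, their one-row-to-many-rows shape is easy to picture: it is the same shape as flatMap over a Scala collection. A minimal plain-Scala sketch (the sample data below is made up for illustration and is not Spark API):

```scala
// explode-style expansion: each (classId, scores) row fans out into
// one (classId, score) row per element of the scores sequence.
val rows = Seq(("Class1", Seq(68, 69, 90)), ("Class2", Seq(96, 74)))
val exploded = rows.flatMap { case (classId, scores) =>
  scores.map(score => (classId, score))
}
// exploded: Seq(("Class1",68), ("Class1",69), ("Class1",90), ("Class2",96), ("Class2",74))
```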

Test data

data/student_scores.json


{"id":1, "studentId":111,"language":68,"math":69,"english":90,"classId":"Class1","departmentId":"Economy"}
{"id":2, "studentId":112,"language":73,"math":80,"english":96,"classId":"Class1","departmentId":"Economy"}
{"id":3, "studentId":113,"language":90,"math":74,"english":75,"classId":"Class1","departmentId":"Economy"}
{"id":4, "studentId":114,"language":89,"math":94,"english":93,"classId":"Class1","departmentId":"Economy"}
{"id":5, "studentId":115,"language":99,"math":93,"english":89,"classId":"Class1","departmentId":"Economy"}
{"id":6, "studentId":121,"language":96,"math":74,"english":79,"classId":"Class2","departmentId":"Economy"}
{"id":7, "studentId":122,"language":89,"math":86,"english":85,"classId":"Class2","departmentId":"Economy"}
{"id":8, "studentId":123,"language":70,"math":78,"english":61,"classId":"Class2","departmentId":"Economy"}
{"id":9, "studentId":124,"language":76,"math":70,"english":76,"classId":"Class2","departmentId":"Economy"}
{"id":10,"studentId":211,"language":89,"math":93,"english":60,"classId":"Class1","departmentId":"English"}
{"id":11,"studentId":212,"language":76,"math":83,"english":75,"classId":"Class1","departmentId":"English"}
{"id":12,"studentId":213,"language":71,"math":94,"english":90,"classId":"Class1","departmentId":"English"}
{"id":13,"studentId":214,"language":94,"math":94,"english":66,"classId":"Class1","departmentId":"English"}
{"id":14,"studentId":215,"language":84,"math":82,"english":73,"classId":"Class1","departmentId":"English"}
{"id":15,"studentId":216,"language":85,"math":74,"english":93,"classId":"Class1","departmentId":"English"}
{"id":16,"studentId":221,"language":77,"math":99,"english":61,"classId":"Class2","departmentId":"English"}
{"id":17,"studentId":222,"language":80,"math":78,"english":96,"classId":"Class2","departmentId":"English"}
{"id":18,"studentId":223,"language":79,"math":74,"english":96,"classId":"Class2","departmentId":"English"}
{"id":19,"studentId":224,"language":75,"math":80,"english":78,"classId":"Class2","departmentId":"English"}
{"id":20,"studentId":225,"language":82,"math":85,"english":63,"classId":"Class2","departmentId":"English"}

Spark SQL and Spark DataFrame UDF

package com.bigData.spark

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions


/**
  * Author: Wang Pei
  * License: Copyright(c) Pei.Wang
  * Summary:
  *  UDF: one input row, one output row.
  *  spark 2.2.2
  *
  */
object SparkDataFrameSqlUDF {
  def main(args: Array[String]): Unit = {

    // Set the log level
    Logger.getLogger("org").setLevel(Level.WARN)

    // Spark session
    val spark = SparkSession.builder().appName(this.getClass.getSimpleName).master("local[3]").getOrCreate()
    import spark.implicits._

    // Read the data
    val dataFrame = spark.read.json("data/student_scores.json")

    /** 1) Spark DataFrame UDF */
    // Register the UDF (wrap an ordinary method as a UDF)
    val avgScoreUDF = functions.udf[Double,Int,Int,Int](avgScorePerStudent)
    // Use the UDF
    dataFrame.withColumn("avgScore",avgScoreUDF($"language",$"math",$"english")).show(false)

    /** 2) Spark SQL UDF */
    dataFrame.createOrReplaceTempView("tmp_student_scores")
    // Register the UDF (wrap an ordinary method as a UDF)
    spark.udf.register("avgScoreUDF",functions.udf[Double,Int,Int,Int](avgScorePerStudent))
    // Use the UDF
    spark.sql("select *,avgScoreUDF(language,math,english) as avgScore from tmp_student_scores").show()

  }

  /** UDF: one input row, one output row. Here it computes each student's (each row's) average score. */
  def avgScorePerStudent(language:Int,math:Int,english:Int):Double={
    ((language+math+english)/3.0).formatted("%.2f").toDouble
  }

}
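Note that the rounding in avgScorePerStudent goes through a String round-trip (formatted("%.2f").toDouble), which formats with the JVM's default locale and can fail to parse back in locales that use a comma as the decimal separator. A locale-independent alternative, sketched with BigDecimal (the standalone avgScore helper below is illustrative, not part of the listing above):

```scala
// Round to two decimal places numerically instead of via a String
// round-trip, so the result does not depend on the default locale.
def avgScore(language: Int, math: Int, english: Int): Double =
  BigDecimal((language + math + english) / 3.0)
    .setScale(2, BigDecimal.RoundingMode.HALF_UP)
    .toDouble
```

For the first row of the test data (68, 69, 90) this yields 75.67, matching the original method.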

Spark SQL and Spark DataFrame UDAF

package com.bigData.spark

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * Author: Wang Pei
  * License: Copyright(c) Pei.Wang
  * Summary:
  *   UDAF: multiple input rows, one output row.
  *   spark 2.2.2
  *
  */
object SparkDataFrameSqlUDAF {
  def main(args: Array[String]): Unit = {


    // Set the log level
    Logger.getLogger("org").setLevel(Level.WARN)

    // Spark session
    val spark = SparkSession.builder().appName(this.getClass.getSimpleName).master("local[3]").getOrCreate()
    import spark.implicits._

    // Read the data
    val dataFrame = spark.read.json("data/student_scores.json")

    /** 1) Spark DataFrame UDAF */
    // Instantiate the UDAF
    val avgScorePerClassUDAF = new AvgScorePerClass
    // Use the UDAF
    dataFrame.groupBy($"departmentId",$"classId").agg(avgScorePerClassUDAF($"language").as("avgScorePerClass")).show(false)

    /** 2) Spark SQL UDAF */
    dataFrame.createOrReplaceTempView("tmp_student_scores")
    // Register the UDAF
    spark.udf.register("avgScorePerClassUDAF",new AvgScorePerClass)
    // Use the UDAF
    spark.sql("select departmentId,classId,avgScorePerClassUDAF(language) as avgScorePerClass from tmp_student_scores group by departmentId,classId").show(false)

  }

}

/** UDAF: multiple input rows, one output row; usually used with GROUP BY. Here it computes the average language score per department (departmentId) and class (classId). */
/** UDAF implementation: extend UserDefinedAggregateFunction and override its methods. (The original type parameter and constructor argument were unused and have been dropped.) */
class AvgScorePerClass extends UserDefinedAggregateFunction{

  //Input Type
  override def inputSchema: StructType = {
    new StructType().add("value",IntegerType)
  }

  //Buffer Type
  override def bufferSchema: StructType = {
    new StructType().add("sum",IntegerType).add("count",IntegerType)
  }

  //Return Type
  override def dataType: DataType = DoubleType

  //Deterministic
  override def deterministic: Boolean = true

  //Initializes Aggregation Buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0
    buffer(1)=0
  }

  //Update Aggregation Buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0)=buffer.getAs[Int](0)+input.getAs[Int](0)
    buffer(1)=buffer.getAs[Int](1)+1
  }

  //Merge Aggregation Buffer
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
    buffer1(1) = buffer1.getAs[Int](1) + buffer2.getAs[Int](1)
  }

  //Calculate Aggregation Result
  override def evaluate(buffer: Row): Any = {
    buffer.getAs[Int](0).toDouble/buffer.getAs[Int](1)
  }
}
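The buffer lifecycle above (initialize, then update once per row, then merge buffers across partitions, then evaluate) can be modeled in plain Scala to make the semantics concrete. The AvgBuffer class below is an illustrative stand-in for the MutableAggregationBuffer, not Spark API:

```scala
// Plain-Scala model of the UDAF buffer: a (sum, count) pair that is
// updated one row at a time and merged across partitions.
case class AvgBuffer(sum: Int, count: Int) {
  // Per-row update, mirroring update(buffer, input)
  def update(value: Int): AvgBuffer = AvgBuffer(sum + value, count + 1)
  // Cross-partition merge, mirroring merge(buffer1, buffer2)
  def merge(other: AvgBuffer): AvgBuffer = AvgBuffer(sum + other.sum, count + other.count)
  // Final result, mirroring evaluate(buffer)
  def evaluate: Double = sum.toDouble / count
}

// Language scores of Economy/Class2 (rows 6-9 of the test data),
// split across two hypothetical partitions.
val partition1 = Seq(96, 89)
val partition2 = Seq(70, 76)
val b1 = partition1.foldLeft(AvgBuffer(0, 0))((buf, v) => buf.update(v))
val b2 = partition2.foldLeft(AvgBuffer(0, 0))((buf, v) => buf.update(v))
val result = b1.merge(b2).evaluate  // (96 + 89 + 70 + 76) / 4 = 82.75
```

Spark only ever sees the buffer through update and merge, which is why merge must be able to combine two partially aggregated buffers, not two raw rows.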

