UDF、UDAF、UDTF
- UDF:用户自定义函数(User Defined Function)。一行输入一行输出。
- UDAF: 用户自定义聚合函数(User Defined Aggregate Function)。多行输入一行输出。
- UDTF: 用户自定义表函数(User Defined Table Generating Function)。一行输入多行输出。如hive/spark中的explode、json_tuple函数。
UDTF不常用,这里只总结UDF和UDAF。
测试数据
data/student_scores.json
{"id":1, "studentId":111,"language":68,"math":69,"english":90,"classId":"Class1","departmentId":"Economy"}
{"id":2, "studentId":112,"language":73,"math":80,"english":96,"classId":"Class1","departmentId":"Economy"}
{"id":3, "studentId":113,"language":90,"math":74,"english":75,"classId":"Class1","departmentId":"Economy"}
{"id":4, "studentId":114,"language":89,"math":94,"english":93,"classId":"Class1","departmentId":"Economy"}
{"id":5, "studentId":115,"language":99,"math":93,"english":89,"classId":"Class1","departmentId":"Economy"}
{"id":6, "studentId":121,"language":96,"math":74,"english":79,"classId":"Class2","departmentId":"Economy"}
{"id":7, "studentId":122,"language":89,"math":86,"english":85,"classId":"Class2","departmentId":"Economy"}
{"id":8, "studentId":123,"language":70,"math":78,"english":61,"classId":"Class2","departmentId":"Economy"}
{"id":9, "studentId":124,"language":76,"math":70,"english":76,"classId":"Class2","departmentId":"Economy"}
{"id":10,"studentId":211,"language":89,"math":93,"english":60,"classId":"Class1","departmentId":"English"}
{"id":11,"studentId":212,"language":76,"math":83,"english":75,"classId":"Class1","departmentId":"English"}
{"id":12,"studentId":213,"language":71,"math":94,"english":90,"classId":"Class1","departmentId":"English"}
{"id":13,"studentId":214,"language":94,"math":94,"english":66,"classId":"Class1","departmentId":"English"}
{"id":14,"studentId":215,"language":84,"math":82,"english":73,"classId":"Class1","departmentId":"English"}
{"id":15,"studentId":216,"language":85,"math":74,"english":93,"classId":"Class1","departmentId":"English"}
{"id":16,"studentId":221,"language":77,"math":99,"english":61,"classId":"Class2","departmentId":"English"}
{"id":17,"studentId":222,"language":80,"math":78,"english":96,"classId":"Class2","departmentId":"English"}
{"id":18,"studentId":223,"language":79,"math":74,"english":96,"classId":"Class2","departmentId":"English"}
{"id":19,"studentId":224,"language":75,"math":80,"english":78,"classId":"Class2","departmentId":"English"}
{"id":20,"studentId":225,"language":82,"math":85,"english":63,"classId":"Class2","departmentId":"English"}
Spark SQL 与 Spark DataFrame UDF
package com.bigData.spark
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions
/**
* Author: Wang Pei
* License: Copyright(c) Pei.Wang
* Summary:
* UDF:一行输入一行输出。
* spark 2.2.2
*
*/
object SparkDataFrameSqlUDF {
  def main(args: Array[String]): Unit = {
    // Keep console output readable: warnings and above only.
    Logger.getLogger("org").setLevel(Level.WARN)

    // Local Spark session with 3 worker threads.
    val spark = SparkSession.builder().appName(this.getClass.getSimpleName).master("local[3]").getOrCreate()
    import spark.implicits._

    // Load the test data (one JSON object per line).
    val dataFrame = spark.read.json("data/student_scores.json")

    /** 1) Spark DataFrame UDF */
    // Wrap the plain method as a UDF: (language, math, english) => average score.
    val avgScoreUDF = functions.udf[Double, Int, Int, Int](avgScorePerStudent)
    // Apply the UDF row-by-row to produce a new column.
    dataFrame.withColumn("avgScore", avgScoreUDF($"language", $"math", $"english")).show(false)

    /** 2) Spark SQL UDF */
    dataFrame.createOrReplaceTempView("tmp_student_scores")
    // Register the same function under a SQL-visible name.
    spark.udf.register("avgScoreUDF", functions.udf[Double, Int, Int, Int](avgScorePerStudent))
    // Use the registered UDF from plain SQL.
    spark.sql("select *,avgScoreUDF(language,math,english) as avgScore from tmp_student_scores").show()
  }

  /**
   * UDF: one row in, one value out. Average of the three scores for a single
   * student, rounded to 2 decimal places (HALF_UP, matching `%.2f`).
   *
   * Fix: the original used `.formatted("%.2f").toDouble`, which formats with
   * the JVM default locale — in comma-decimal locales (e.g. de_DE) the result
   * is "75,67" and `toDouble` throws NumberFormatException. Rounding
   * arithmetically with BigDecimal is locale-independent.
   */
  def avgScorePerStudent(language: Int, math: Int, english: Int): Double = {
    BigDecimal((language + math + english) / 3.0)
      .setScale(2, BigDecimal.RoundingMode.HALF_UP)
      .toDouble
  }
}
Spark SQL 与 Spark DataFrame UDAF
package com.bigData.spark
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
* Author: Wang Pei
* License: Copyright(c) Pei.Wang
* Summary:
* UDAF:多行输入一行输出。
* spark 2.2.2
*
*/
object SparkDataFrameSqlUDAF {
  def main(args: Array[String]): Unit = {
    // Keep console output readable: warnings and above only.
    Logger.getLogger("org").setLevel(Level.WARN)

    // Local Spark session with 3 worker threads.
    val session = SparkSession.builder().appName(this.getClass.getSimpleName).master("local[3]").getOrCreate()
    import session.implicits._

    // Input: one JSON record per student.
    val scores = session.read.json("data/student_scores.json")

    /** 1) Spark DataFrame UDAF */
    // Instantiate the aggregator for an integer-typed input column.
    val perClassAvg = new AvgScorePerClass[Double](IntegerType)
    // Average `language` per (department, class) group.
    scores
      .groupBy($"departmentId", $"classId")
      .agg(perClassAvg($"language").as("avgScorePerClass"))
      .show(false)

    /** 2) Spark SQL UDAF */
    scores.createOrReplaceTempView("tmp_student_scores")
    // Register a fresh instance under a SQL-visible name.
    session.udf.register("avgScorePerClassUDAF", new AvgScorePerClass[Double](IntegerType))
    // Same aggregation expressed in SQL.
    session
      .sql("select departmentId,classId,avgScorePerClassUDAF(language) as avgScorePerClass from tmp_student_scores group by departmentId,classId")
      .show(false)
  }
}
/**
 * UDAF: many rows in, one value out; normally used with `group by`. Here it
 * computes the average `language` score per department (`departmentId`) and
 * class (`classId`).
 *
 * Implemented by extending [[UserDefinedAggregateFunction]] and overriding
 * its methods.
 *
 * Fixes over the original:
 *  - `valType` was dead: `inputSchema` hard-coded `IntegerType`, silently
 *    ignoring the type passed at the call sites. It is now honored
 *    (backward compatible — both call sites pass `IntegerType`).
 *  - `update` no longer counts SQL NULL inputs (a null unboxed via
 *    `getAs[Int]` as 0 and still bumped the count, skewing the average);
 *    standard SQL `avg` skips nulls.
 *  - `evaluate` guards against an empty buffer (count == 0) and returns
 *    SQL null instead of dividing by zero.
 *
 * @param valType Spark SQL type of the single input column; assumed to hold
 *                integral values (they are accumulated as Int).
 * @tparam T unused; kept for source compatibility with existing call sites.
 */
class AvgScorePerClass[T](valType: DataType) extends UserDefinedAggregateFunction {

  // Input: one column of `valType`.
  override def inputSchema: StructType = {
    new StructType().add("value", valType)
  }

  // Intermediate state: running sum and row count.
  override def bufferSchema: StructType = {
    new StructType().add("sum", IntegerType).add("count", IntegerType)
  }

  // Final result: the average as a Double.
  override def dataType: DataType = DoubleType

  // Same input rows always produce the same output.
  override def deterministic: Boolean = true

  // Start every aggregation with sum = 0, count = 0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
    buffer(1) = 0
  }

  // Fold one input row into the buffer; NULL inputs are skipped, matching
  // SQL avg() semantics. Reading through Number tolerates any boxed numeric
  // the chosen valType yields (truncated to Int).
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getAs[Int](0) + input.getAs[Number](0).intValue()
      buffer(1) = buffer.getAs[Int](1) + 1
    }
  }

  // Combine two partial buffers (per-partition pre-aggregates).
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
    buffer1(1) = buffer1.getAs[Int](1) + buffer2.getAs[Int](1)
  }

  // Average = sum / count; SQL null when no (non-null) rows were seen,
  // avoiding division by zero.
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getAs[Int](1)
    if (count == 0) null else buffer.getAs[Int](0).toDouble / count
  }
}