一、介绍
Spark SQL中自定义函数包括UDF和UDAF
(先前已经发布一篇SparkSQL的UDF函数,现在为大家讲解一下UDAF自定义聚合函数)
自定义函数
UDF:一进一出
UDAF:多进一出 √
二、UDAF函数
UDA:户自定义聚合函数,类似在group by之后使用的sum,avg等。
首先创建class继承接口UserDefineAggregateFunction,并实现其中的方法。
这里的UDAF,则可以针对多行输入,进行聚合计算,返回一个输出,功能更加强大。
package SparkSQL import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ class StringCount extends UserDefinedAggregateFunction{ //指的是输入数据的类型 override def inputSchema: StructType = { StructType(Array(StructField("str",StringType,true))) } //bufferSchema指的是中间进行聚合时,所处理的数据类型 override def bufferSchema: StructType = { StructType(Array(StructField("count",IntegerType,true))) } //dataType指的是函数返回值的类型 override def dataType: DataType = { IntegerType } override def deterministic: Boolean = { true } //为每个分组的数据执行初始化操作 override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0)=0 } //指的是,每个分组有新的值进来的时候,如何进行分组对应的聚合值的计算 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { buffer(0)=buffer.getAs[Int](0)+1 } //由于spark是分布式的。所以一个分组的数据,可能会在不同的节点上进行局部聚合,就是update //但是,最后一个分组会在节点上聚合值,要进行merge,也就是合并 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1(0)=buffer1.getAs[Int](0)+buffer2.getAs[Int](0) } //最后指的是,一个分组聚合之,如何通过中间的缓存聚合之,最后返回一个最终的聚合值 override def evaluate(buffer: Row): Any = { buffer.getAs[Int](0) } }
编写好class类,接下来编写测试
package SparkSQL import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.{SparkConf, SparkContext} object UDAF { def main(args: Array[String]): Unit = { val conf = new SparkConf().setMaster("local").setAppName("UDF") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) //模拟构造数据 val names = Array("leo","Marry","Jack","Tom","Tom","Tom","leo","leo") val nameRDD=sc.parallelize(names,5) val namesRowRDD=nameRDD.map{name=>Row(name)} val structType = StructType(Array(StructField("name",StringType,true))) val namesDF=sqlContext.createDataFrame(namesRowRDD,structType) //注册一张零时表 namesDF.registerTempTable("names") //定义和注册自定义函数 sqlContext.udf.register("strCount",new StringCount) //使用自定义函数 sqlContext.sql("select name,strCount(name) as a from names group by name order by a desc").collect().foreach(println) // sqlContext.sql("select name,count(*) from names group by name").collect().foreach(println) } }