UDAF使用
需要处理的数据:
{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}
- 读取文件
object MyUdafaa {
def main(args: Array[String]): Unit = {
val conf: SparkConf = new SparkConf().setAppName("udaf").setMaster("local[2]")
val sc = new SparkContext(conf)
val sct: SQLContext = new SQLContext(sc)
// 读取文件
val file: DataFrame = sct.read.json("file.txt")
// 创建视图,并把读取的数据放进视图表里
file.registerTempTable("student")
// 创建函数,函数名,方法名(可以随意命名)
sct.udf.register("udaf",StringCount)
// 查询视图表里的所有数据
sct.sql("select * from student").show()
/* +-------+------+
| name|salary|
+-------+------+
|Michael| 3000|
| Andy| 4500|
| Justin| 3500|
| Berta| 4000|
+-------+------+ 这是打印出来的数据,还没有经过函数的调用 */
// 这里使用函数
sct.sql("select udaf(salary) as avga from student").show()
/* +------+
| avga|
+------+
|3750.0|
+------+ 这是处理后的结果*/
sc.stop()
}
}
- 继承UDAF(UserDefinedAggregateFunction)
object StringCount extends UserDefinedAggregateFunction{
// 输入的数据类型
override def inputSchema: StructType =
// 这是传入的数据类型,StructField真正的数据类型, ::Nil是创建一个list列表
StructType(StructField("salary",IntegerType)::Nil)
// 中间聚合处理时,需要处理的数据类型,中间聚合处理时,需要处理的数据类型
override def bufferSchema: StructType =
StructType(StructField("sum",IntegerType)::StructField("avg",IntegerType)::Nil)
// 函数的返回类型
override def dataType: DataType = DoubleType
// 表示如果有相同的输入是否存在相同的输出,如果是则true
override def deterministic: Boolean = true
// 为每个分组的数据初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0)=0
buffer(1)=0
}
// 每个分组,有新的值进来时,如何进行分组的聚合计算,聚合的时候需要调用该方法,可以理解为map端的一个小聚合
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
// 获取这一行中的工资,然后将工资加入到sum中
buffer(0) = buffer.getInt(0) + input.getInt(0)
// 将工资的个数加1
buffer(1) = buffer.getInt(1) + 1
}
// 由于Spark是分布式的,所以一个分组的数据,可能会在不同的节点上进行局部聚合,
// 就是update,但是最后一个分组,在各节点上的聚合值,要进行Merge,也就是合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
// 合并总的工资
buffer1(0)=buffer1.getInt(0)+buffer2.getInt(0)
// 合并总的工资个数
buffer1(1)= buffer1.getInt(1) + buffer2.getInt(1)
}
// 一个分组的聚合值,如何通过中间的聚合值,最后返回一个最终的聚合值
override def evaluate(buffer: Row): Any = {
// 取出总的工资 / 总工资个数
buffer.getInt(0).toDouble/buffer.getInt(1)
}
}