UDAF使用

UDAF使用

需要处理的数据:

{"name":"Michael", "salary":3000}
{"name":"Andy", "salary":4500}
{"name":"Justin", "salary":3500}
{"name":"Berta", "salary":4000}
  1. 读取文件
object MyUdafaa {
  def main(args: Array[String]): Unit = {
    val conf: SparkConf = new SparkConf().setAppName("udaf").setMaster("local[2]")
    val sc = new SparkContext(conf)
    val sct: SQLContext = new SQLContext(sc)

    // 读取文件
    val file: DataFrame = sct.read.json("file.txt")
    // 创建视图,并把读取的数据放进视图表里
    file.registerTempTable("student")
    // 创建函数,函数名,方法名(可以随意命名)
    sct.udf.register("udaf",StringCount)
	// 查询视图表里的所有数据
    sct.sql("select * from student").show()
/*  +-------+------+
	|   name|salary|
	+-------+------+
	|Michael|  3000|
	|   Andy|  4500|
	| Justin|  3500|
	|  Berta|  4000|
	+-------+------+  	这是打印出来的数据,还没有经过函数的调用 */

	// 这里使用函数
    sct.sql("select udaf(salary) as avga from student").show()
	/*  +------+
		|  avga|
		+------+
		|3750.0|
		+------+     这是处理后的结果*/
    sc.stop()
  }
}
  1. 继承UDAF(UserDefinedAggregateFunction)
object StringCount extends UserDefinedAggregateFunction{
  // 输入的数据类型
  override def inputSchema: StructType = 
  // 这是传入的数据类型,StructField真正的数据类型, ::Nil是创建一个list列表
  StructType(StructField("salary",IntegerType)::Nil)
  
  // 中间聚合处理时,需要处理的数据类型,中间聚合处理时,需要处理的数据类型
  override def bufferSchema: StructType =
   StructType(StructField("sum",IntegerType)::StructField("avg",IntegerType)::Nil)

	// 函数的返回类型
  override def dataType: DataType = DoubleType

  // 表示如果有相同的输入是否存在相同的输出,如果是则true
  override def deterministic: Boolean = true

  // 为每个分组的数据初始化
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0
    buffer(1)=0
  }

  //  每个分组,有新的值进来时,如何进行分组的聚合计算,聚合的时候需要调用该方法,可以理解为map端的一个小聚合
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =  {
    // 获取这一行中的工资,然后将工资加入到sum中
    buffer(0) = buffer.getInt(0) + input.getInt(0)
    // 将工资的个数加1
    buffer(1) = buffer.getInt(1) + 1
  }

  // 由于Spark是分布式的,所以一个分组的数据,可能会在不同的节点上进行局部聚合,
  // 就是update,但是最后一个分组,在各节点上的聚合值,要进行Merge,也就是合并
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // 合并总的工资
    buffer1(0)=buffer1.getInt(0)+buffer2.getInt(0)
    // 合并总的工资个数
    buffer1(1)= buffer1.getInt(1) + buffer2.getInt(1)
  }

  // 一个分组的聚合值,如何通过中间的聚合值,最后返回一个最终的聚合值
  override def evaluate(buffer: Row): Any = {
    // 取出总的工资 / 总工资个数
    buffer.getInt(0).toDouble/buffer.getInt(1)
  }
}
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值