场景
UDAF = USER DEFINED AGGREGATION FUNCTION
上一篇文章已经介绍了spark sql的窗口函数,并知道Spark sql提供了丰富的内置函数供猿友们使用,辣为何还要用户自定义函数呢?实际的业务场景可能很复杂,内置函数hold不住,所以spark sql提供了可扩展的内置函数接口:哥们,你的业务太变态了,我满足不了你,自己按照我的规范去定义一个sql函数,该怎么折腾就怎么折腾!
例如,MySQL数据库中有一张task表,共两个字段taskid (任务ID)与taskParam(JSON格式的任务请求参数)。简单起见,这里只列出一条记录:
taskid
1
taskParam
{"endAge":["50"],"endDate":["2016-06-21"],"startAge":["10"],"startDate":["2016-06-21"]}
假设应用程序已经读取了mysql中这张表的记录,并通过 DateFrame注册成了一张临时表 task。问题来了:怎么获取taskParam中startAge的第一个值呢?
sqlContext.sql("select taskid,getJsonFieldUDF(taskParm,'startAge')")
这个时候,我们就需要自定义一个UDF函数了,取名getJsonFieldUDF。Java版本的代码大致如下:
package cool.pengych.sparker.product;
import org.apache.spark.sql.api.java.UDF2;
import com.alibaba.fastjson.JSONObject;
/**
* 用户自定义函数
* @author pengyucheng
*/
public class GetJsonObjectUDF implements UDF2<String,String,String>
{
/**
* 获取数组类型json字符串中某一字段的值
*/
@Override
public String call(String json, String field) throws Exception
{
try
{
JSONObject jsonObject = JSONObject.parseObject(json);
return jsonObject.getJSONArray(field).getString(0);
}
catch(Exception e)
{
e.printStackTrace();
}
return null;
}
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
这样的需求在实际项目中是很普遍的:请求参数经常以json格式存储在数据库中,,,完了,越写越多 。这里还是先以Scala实现一个简单的hello world级别的小样为例,来体验udf与udaf的使用好了。
问题
将如下数组:
val bigData = Array("Spark","Hadoop","Flink","Spark","Hadoop","Flink",
"Spark","Hadoop","Flink","Spark","Hadoop","Flink")
中的字符分组聚合并计算出每个字符的长度及字符出现的个数。正常结果
如下:
+------+-----+------+
| Spark| 4| 5|
| Flink| 4| 5|
注:‘spark’ 这个字符的长度为5 ,共出现了4次。
分析
- 自定义个一个求字符串长度的函数
自定义的sql函数,与scala中的普通函数一样,只不过在使用上前者需要先在sqlContext中进行注册。 - 自定义一个聚合函数
按照字符串名称分组后,调用自定义的聚合函数实现累加。
啊,好抽象,直接看代码吧!
代码
package main.scala
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.expressions.MutableAggregationBuffer
/**
* Spark SQL UDAS:user defined aggregation function
* UDF: 函数的输入是一条具体的数据记录,实现上讲就是普通的scala函数-只不过需要注册
* UDAF:用户自定义的聚合函数,函数本身作用于数据集合,能够在具体操作的基础上进行自定义操作
*/
object SparkSQLUDF {
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setMaster("local[*]").setAppName("SparkSQLWindowFunctionOps")
val sc = new SparkContext(conf)
val hiveContext = new SQLContext(sc)
val bigData = Array("Spark","Hadoop","Flink","Spark","Hadoop","Flink","Spark","Hadoop","Flink","Spark","Hadoop","Flink")
val bigDataRDD = sc.parallelize(bigData)
val bigDataRowRDD = bigDataRDD.map(line => Row(line))
val structType = StructType(Array(StructField("name",StringType,true)))
val bigDataDF = hiveContext.createDataFrame(bigDataRowRDD, structType)
bigDataDF.registerTempTable("bigDataTable")
hiveContext.udf.register("computeLength",(input:String) => input.length)
hiveContext.sql("select name,computeLength(name) as length from bigDataTable").show
hiveContext.udf.register("wordCount",new MyUDAF)
hiveContext.sql("select name,wordCount(name) as count,computeLength(name) as length from bigDataTable group by name ").show
}
}
/**
* 用户自定义函数
*/
class MyUDAF extends UserDefinedAggregateFunction
{
/**
* 指定具体的输入数据的类型
* 自段名称随意:Users can choose names to identify the input arguments - 这里可以是“name”,或者其他任意串
*/
override def inputSchema:StructType = StructType(Array(StructField("name",StringType,true)))
/**
* 在进行聚合操作的时候所要处理的数据的中间结果类型
*/
override def bufferSchema:StructType = StructType(Array(StructField("count",IntegerType,true)))
/**
* 返回类型
*/
override def dataType:DataType = IntegerType
/**
* whether given the same input,
* always return the same output
* true: yes
*/
override def deterministic:Boolean = true
/**
* Initializes the given aggregation buffer
*/
override def initialize(buffer:MutableAggregationBuffer):Unit = {buffer(0)=0}
/**
* 在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
* 本地的聚合操作,相当于Hadoop MapReduce模型中的Combiner
*/
override def update(buffer:MutableAggregationBuffer,input:Row):Unit={
buffer(0) = buffer.getInt(0)+1
}
/**
* 最后在分布式节点进行local reduce完成后需要进行全局级别的merge操作
*/
override def merge(buffer1:MutableAggregationBuffer,buffer2:Row):Unit={
buffer1(0) = buffer1.getInt(0)+buffer2.getInt(0)
}
/**
* 返回UDAF最后的计算结果
*/
override def evaluate(buffer:Row):Any = buffer.getInt(0)
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
执行结果
| Spark| 4| 5|
| Flink| 4| 5|
16/06/29 19:30:24 INFO DAGScheduler: Job 3 finished: show at SparkSQLUDF.scala:48, took 1.717878 s
总结
hiveContext.udf.register("computeLength",(input:String) => input.length)
hiveContext.udf.register("wordCount",new MyUDAF)