Simple form
A UDF in its simplest form can be registered directly as shown below; the function body may reference a variable or call a method. (The name is changed here from avg_get to str_concat to avoid colliding with the aggregate function registered later.)
spark.udf.register[String, String]("str_concat", s => s.concat("content"))
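Once registered, the UDF can be called from SQL just like a built-in function. A minimal usage sketch; the view v_people and its string column name are hypothetical placeholders:
spark.sql("select str_concat(name) from v_people").show()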
Custom aggregate function
Define a class that implements the abstract methods of UserDefinedAggregateFunction, register an instance via spark.udf.register on the SparkSession, and then call it directly from SQL:
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{DataFrame, SparkSession}

def main(args: Array[String]): Unit = {
  Logger.getLogger("org").setLevel(Level.ERROR)
  val spark = SparkSession.builder().appName("UserDefinedFunctions").master("local[*]").getOrCreate()
  // Infer column types and treat the first row as a header
  val frame: DataFrame = spark.read.option("inferSchema", true).option("header", true).csv("data\\saledata.csv")
  frame.createTempView("v_udfdata")
  // Register an instance of the UDAF defined below under the name avg_get
  spark.udf.register("avg_get", new MyFunctionsDefined)
  spark.sql(
    """
      |select sid, avg_get(money) from v_udfdata group by sid
      |""".stripMargin).show()
}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}

class MyFunctionsDefined extends UserDefinedAggregateFunction {
  // Schema of the input passed by the caller
  override def inputSchema: StructType = StructType(List(StructField("in", DataTypes.DoubleType)))
  // Schema of the aggregation buffer: a running total and a row count
  override def bufferSchema: StructType = StructType(List(StructField("total", DataTypes.DoubleType), StructField("amount", DataTypes.IntegerType)))
  // Type of the final result
  override def dataType: DataType = DataTypes.DoubleType
  // The same input always yields the same output
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0
  }
  // Fold one input row into the buffer
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getDouble(0) + input.getDouble(0)
    buffer(1) = buffer.getInt(1) + 1
  }
  // Combine partial results from different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0) // sum the totals
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)       // sum the counts
  }
  // total / count = average
  override def evaluate(buffer: Row): Any = buffer.getDouble(0) / buffer.getInt(1)
}
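The registered UDAF is not limited to SQL; it can also be invoked through the DataFrame API. A minimal sketch, reusing frame and the sid/money columns from the example above:
import org.apache.spark.sql.functions.expr
frame.groupBy("sid").agg(expr("avg_get(money)")).show()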
Spark 3.0 simplified version
Since Spark 3.0, UserDefinedAggregateFunction is deprecated; the recommended replacement is a typed Aggregator wrapped with functions.udaf, as below.
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf
import org.apache.spark.sql.{DataFrame, Encoder, Encoders, SparkSession}

def main(args: Array[String]): Unit = {
  Logger.getLogger("org").setLevel(Level.ERROR)
  val spark = SparkSession.builder().appName("UserDefinedFunctions").master("local[*]").getOrCreate()
  val frame: DataFrame = spark.read.option("inferSchema", true).option("header", true).csv("data\\saledata.csv")
  // Typed aggregator: the input is a Double (matching the UDAF above), the buffer is a (count, sum) pair, the output is the average
  val avgvl = new Aggregator[Double, (Int, Double), Double] {
    override def zero: (Int, Double) = (0, 0.0)
    // Fold one input value into the buffer
    override def reduce(b: (Int, Double), a: Double): (Int, Double) = (b._1 + 1, b._2 + a)
    // Combine partial buffers from different partitions
    override def merge(b1: (Int, Double), b2: (Int, Double)): (Int, Double) = (b1._1 + b2._1, b1._2 + b2._2)
    // sum / count = average
    override def finish(reduction: (Int, Double)): Double = reduction._2 / reduction._1
    override def bufferEncoder: Encoder[(Int, Double)] = Encoders.tuple(Encoders.scalaInt, Encoders.scalaDouble)
    override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
  }
  frame.createTempView("v_udfdata")
  // udaf(...) wraps the typed Aggregator so it can be registered like any other UDF
  spark.udf.register("avg_get", udaf(avgvl))
  spark.sql(
    """
      |select sid, avg_get(money) from v_udfdata group by sid
      |""".stripMargin).show()
}
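The same Aggregator can also be used in a type-safe way on a Dataset, without registering it at all, via toColumn. A minimal sketch, assuming the money column can be cast to Double (spark, frame and avgvl as defined above):
import spark.implicits._
val moneyDs = frame.select($"money".cast("double")).as[Double]
moneyDs.select(avgvl.toColumn.name("avg_money")).show()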