Spark自定义函数

简单形式

简单的自定义函数可以直接用下面的形式注册,传入一个函数字面量、已有变量或方法均可:

spark.udf.register[String,String]("avg_get",func =>func.toString.concat("内容"))

自定义聚合函数

定义一个类,实现 UserDefinedAggregateFunction 抽象类的各个方法,然后通过 SparkSession 的 spark.udf.register 注册函数,即可在 SQL 中直接调用。


  /**
   * Demo entry point: registers the UDAF `MyFunctionsDefind` as SQL function
   * `avg_get` and runs a GROUP BY query over a local CSV file.
   */
  def main(args: Array[String]): Unit = {
    // Silence Spark's verbose INFO logging before the session starts.
    // (The original set this twice; once is enough.)
    Logger.getLogger("org").setLevel(Level.ERROR)
    val spark = SparkSession.builder().appName("用户自定义函数").master("local[*]").getOrCreate()
    // inferSchema + header: let Spark derive column types and use the first row as names.
    val frame: DataFrame = spark.read.option("inferSchema",true).option("header", true).csv("data\\saledata.csv")
    import spark.implicits._
    import org.apache.spark.sql.functions._

    frame.createTempView("v_udfdata")
    // Register the aggregate function so SQL can call it by name.
    spark.udf.register("avg_get",new MyFunctionsDefind)
    spark.sql(
      """
        |select sid,avg_get(money) from v_udfdata group by sid
        |""".stripMargin).show()
    // Release local Spark resources when done.
    spark.stop()
  }
}
/**
 * User-defined aggregate function that computes the average of a Double column.
 * Pre-Spark-3.0 API: UserDefinedAggregateFunction is deprecated in 3.0 in
 * favor of a typed Aggregator wrapped with functions.udaf (see the second example).
 */
class MyFunctionsDefind extends UserDefinedAggregateFunction{
  // Input schema: a single Double column (the value being averaged).
  override def inputSchema: StructType = StructType(List(StructField("in",DataTypes.DoubleType)))

  // Aggregation buffer: running sum ("total") and row count ("ammont").
  override def bufferSchema: StructType = StructType(List(StructField("total",DataTypes.DoubleType),StructField("ammont",DataTypes.IntegerType)))

  // Result type: the average, as a Double.
  override def dataType: DataType = DataTypes.DoubleType

  // An average of the same inputs always yields the same result, so this
  // function IS deterministic. The original returned false, which needlessly
  // blocks optimizations that rely on determinism.
  override def deterministic: Boolean = true

  // Start each group with sum = 0.0 and count = 0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0)=0.0
    buffer(1)=0
  }

  // Fold one input row into the buffer: bump the count, add the value to the sum.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(1) = buffer.getInt(1)+1
    buffer(0) = buffer.getDouble(0)+input.getDouble(0)
  }

  // Combine partial buffers produced by different partitions.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit ={
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0) // combine sums
    // BUG FIX: must add buffer2's row count, not a constant 1. The original
    // `buffer1.getInt(1) + 1` undercounted whenever the merged partition had
    // aggregated more than one row, yielding an inflated average.
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }

  // Final result: sum / count. GROUP BY guarantees at least one row per group,
  // so the count is never zero here.
  override def evaluate(buffer: Row): Any =buffer.getDouble(0) / buffer.getInt(1)
}

Spark 3.0 简化版本:使用强类型的 Aggregator 配合 functions.udaf 注册,替代已废弃的 UserDefinedAggregateFunction

  /**
   * Demo entry point (Spark 3.0 style): builds a typed Aggregator computing an
   * average, wraps it with `udaf`, registers it as SQL function `avg_get`,
   * and runs the same GROUP BY query as the pre-3.0 example.
   */
  def main(args: Array[String]): Unit = {
    // Silence Spark's verbose INFO logging before the session starts.
    // (The original set this twice; once is enough.)
    Logger.getLogger("org").setLevel(Level.ERROR)
    val spark = SparkSession.builder().appName("用户自定义函数").master("local[*]").getOrCreate()
    val frame: DataFrame = spark.read.option("inferSchema",true).option("header", true).csv("data\\saledata.csv")
    import spark.implicits._
    import org.apache.spark.sql.functions._

    // Typed aggregator: IN = Int, BUF = (count, sum), OUT = Double (the average).
    // NOTE(review): `money` is read with inferSchema and may be inferred as
    // Double, while this Aggregator declares an Int input — confirm against the CSV.
    val avgvl=new Aggregator[Int,(Int,Double),Double] {
      // Zero buffer: count 0, sum 0.0.
      override def zero: (Int, Double) = (0,0.0)

      // Fold one value in: bump the count, add the value to the sum.
      override def reduce(b: (Int, Double), a: Int): (Int, Double) = (b._1+1,b._2+a.toDouble)

      // Combine partial buffers from different partitions.
      override def merge(b1: (Int, Double), b2: (Int, Double)): (Int, Double) = (b1._1+b2._1,b1._2+b2._2)

      // Final result: sum / count.
      override def finish(reduction: (Int, Double)): Double = reduction._2/reduction._1

      // Encoders tell Spark how to serialize the buffer and output types.
      override def bufferEncoder: Encoder[(Int, Double)] = Encoders.tuple(Encoders.scalaInt,Encoders.scalaDouble)

      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }

    frame.createTempView("v_udfdata")
    // `udaf` adapts the strongly-typed Aggregator for untyped SQL use.
    spark.udf.register("avg_get",udaf(avgvl))

    spark.sql(
      """
        |select sid,avg_get(money) from v_udfdata group by sid
        |""".stripMargin).show()
    // Release local Spark resources when done.
    spark.stop()
  }

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值