Spark-sql自定义函数

2.3 SQL语法的用户自定义函数

2.3.1 UDF

1)UDF:一行进入,一行出

2)代码实现

package com.atguigu.sparksql;

import org.apache.spark.SparkConf;

import org.apache.spark.sql.Dataset;

import org.apache.spark.sql.Row;

import org.apache.spark.sql.SparkSession;

import org.apache.spark.sql.api.java.UDF1;

import org.apache.spark.sql.expressions.UserDefinedFunction;

import org.apache.spark.sql.types.DataTypes;

import static org.apache.spark.sql.functions.udf;

public class Test04_UDF {

    public static void main(String[] args) {

        //1. 创建配置对象

        SparkConf conf = new SparkConf().setAppName("sparksql").setMaster("local[*]");

        //2. 获取sparkSession

        SparkSession spark = SparkSession.builder().config(conf).getOrCreate();

        //3. 编写代码

        Dataset<Row> lineRDD = spark.read().json("input/user.json");

        lineRDD.createOrReplaceTempView("user");

        // 定义一个函数

        // 需要首先导入依赖import static org.apache.spark.sql.functions.udf;

        UserDefinedFunction addName = udf(new UDF1<String, String>() {

            @Override

            public String call(String s) throws Exception {

                return s + " 大侠";

            }

        }, DataTypes.StringType);

        spark.udf().register("addName",addName);

        spark.sql("select addName(name) newName from user")

                .show();

        // lambda表达式写法

        spark.udf().register("addName1",(UDF1<String,String>) name -> name + " 大侠",DataTypes.StringType);

        //4. 关闭sparkSession

        spark.close();

    }

}

2.3.2 UDAF

1)UDAF:输入多行,返回一行。通常和groupBy一起使用,如果直接使用UDAF函数,默认将所有的数据合并在一起。

2)Spark3.x推荐使用extends Aggregator自定义UDAF,属于强类型的Dataset方式。

3)Spark2.x使用extends UserDefinedAggregateFunction,属于弱类型的DataFrame

4)案例实操

需求:实现求平均年龄,自定义UDAF,MyAvg(age)

(1)自定义聚合函数实现-强类型

package com.atguigu.sparksql;

import org.apache.spark.SparkConf;

import org.apache.spark.sql.Encoder;

import org.apache.spark.sql.Encoders;

import org.apache.spark.sql.SparkSession;

import org.apache.spark.sql.expressions.Aggregator;

import java.io.Serializable;

import static org.apache.spark.sql.functions.udaf;

public class Test05_UDAF {

    public static void main(String[] args) {

        //1. 创建配置对象

        SparkConf conf = new SparkConf().setAppName("sparksql").setMaster("local[*]");

        //2. 获取sparkSession

        SparkSession spark = SparkSession.builder().config(conf).getOrCreate();

        //3. 编写代码

        spark.read().json("input/user.json").createOrReplaceTempView("user");

        // 注册需要导入依赖 import static org.apache.spark.sql.functions.udaf;

        spark.udf().register("avgAge",udaf(new MyAvg(),Encoders.LONG()));

        spark.sql("select avgAge(age) newAge from user").show();

        //4. 关闭sparkSession

        spark.close();

    }

    public static class Buffer implements Serializable {

        private Long sum;

        private Long count;

        public Buffer() {

        }

        public Buffer(Long sum, Long count) {

            this.sum = sum;

            this.count = count;

        }

        public Long getSum() {

            return sum;

        }

        public void setSum(Long sum) {

            this.sum = sum;

        }

        public Long getCount() {

            return count;

        }

        public void setCount(Long count) {

            this.count = count;

        }

    }

    public static class MyAvg extends Aggregator<Long,Buffer,Double>{

        @Override

        public Buffer zero() {

            return new Buffer(0L,0L);

        }

        @Override

        public Buffer reduce(Buffer b, Long a) {

            b.setSum(b.getSum() + a);

            b.setCount(b.getCount() + 1);

            return b;

        }

        @Override

        public Buffer merge(Buffer b1, Buffer b2) {

            b1.setSum(b1.getSum() + b2.getSum());

            b1.setCount(b1.getCount() + b2.getCount());

            return b1;

        }

        @Override

        public Double finish(Buffer reduction) {

            return reduction.getSum().doubleValue() / reduction.getCount();

        }

        @Override

        public Encoder<Buffer> bufferEncoder() {

            // 可以用kryo进行优化

            return Encoders.kryo(Buffer.class);

        }

        @Override

        public Encoder<Double> outputEncoder() {

            return Encoders.DOUBLE();

        }

    }

}

2.3.3 UDTF(没有)

输入一行,返回多行(Hive)。

SparkSQL中没有UDTF,需要使用算子类型的flatMap先完成拆分。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

走过冬季

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值