Spark学习笔记:Spark UDF/UDAF

本文介绍了Spark中的用户定义函数(UDF)和用户定义聚合函数(UDAF)。首先展示了如何使用Spark UDF,包括读取CSV数据,创建DataFrame,注册临时视图,定义和调用自定义字符串拼接函数。接着,通过一个Java实现的UDAF例子,解释了如何计算分组平均数。提供了测试数据和SQL语句的示例,演示了如何对数据进行分组并计算平均值。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

 

Spark UDF/UDAF

1.SparkUDF

2.Spark UDAF

(1)求平均数


Spark UDF/UDAF

1.SparkUDF

(1)测试数据: emp.csv文件
(2)定义case class
         case class Emp(empno:Int,ename:String,job:String,mgr:String,hiredate:String,sal:Int,comm:String,deptno:Int)
(3)导入emp.csv的文件
          val lineRDD = sc.textFile("/emp.csv").map(_.split(","))
(4)生成DataFrame
         val allEmp = lineRDD.map(x=>Emp(x(0).toInt,x(1),x(2),x(3),x(4),x(5).toInt,x(6),x(7).toInt))
         val empDF = allEmp.toDF
(5)注册成一个临时视图
         empDF.createOrReplaceTempView("emp")
(6)自定义一个函数,拼加字符串
         spark.sqlContext.udf.register("concatstr",(s1:String,s2:String)=>s1+"***"+s2)
(7)调用自定义函数,将ename和job这两个字段拼接在一起
         spark.sql("select concatstr(ename,job) from emp").show

测试结果

2.Spark UDAF

UDAF就是用户自定义聚合函数,比如平均值,最大最小值,累加,拼接等。

这里以求平均数为例,并用Java实现,以后应该会陆续更新一些自带函数的实现当练习。

(1)求平均数

package SparkUDAF;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

public class MyAvg extends UserDefinedAggregateFunction {

    @Override
    public StructType inputSchema() {
        //输入数据的类型,输入的是字符串
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("InputData", DataTypes.StringType, true));

        return DataTypes.createStructType(structFields);
    }

    @Override
    public StructType bufferSchema() {

        //聚合操作时,所处理的数据的数据类型,在这个例子里求平均数,要先求和(Sum),然后除以个数(Amount),所以这里需要处理两个字段
        //注意因为用了ArrayList,所以是有序的
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("Amount", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("Sum", DataTypes.IntegerType, true));

        return DataTypes.createStructType(structFields);
    }

    @Override
    public DataType dataType() {
        //UDAF计算后的返回值类型
        return DataTypes.IntegerType;
    }

    @Override
    public boolean deterministic() {
        //判断输入和输出的类型是否一致,如果返回的是true则表示一致,false表示不一致,自行设置
        return false;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        /*
        对辅助字段进行初始化,就是上面定义的field1和field2
        第一个辅助字段的下标为0,初始值为0
        第二个辅助字段的下标为1,初始值为0
        */
        buffer.update(0, 0);
        buffer.update(1, 0);
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        /*
        update可以认为是在每一个节点上都会对数据执行的操作,UDAF函数执行的时候,数据会被分发到每一个节点上,就是每一个分区
		buffer.getInt(0)获取的是上一次聚合后的值,input就是当前获取的数据
		*/

        //修改辅助字段的值,buffer.getInt(x)获取的是上一次聚合后的值,x表示
        buffer.update(0, buffer.getInt(0) + 1); //表示某个数字的个数
        buffer.update(1, buffer.getInt(1) + Integer.parseInt(input.getString(0))); //表示某个数字的总和
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        /*
        merge:对每个分区的结果进行合并,每个分布式的节点上做完update之后就要做一个全局合并的操作
        合并每一个update操作的结果,将各个节点上的数据合并起来
        buffer1.getInt(0) : 上一次聚合后的值
		buffer2.getInt(0) : 这次计算传入进来的update的结果
		*/

        //对第一个字段Amount进行求和,求出总个数
        buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
        //对第二个字段Sum进行求和,求出总和
        buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1));
    }

    @Override
    public Object evaluate(Row buffer) {
        //表示最终计算的结果,第二个参数表示和值,第一个参数表示个数
        return buffer.getInt(1) / buffer.getInt(0);
    }
}

测试数据
数据形式key^value,^是分隔符

a^4
a^6
b^2
b^4
b^6
对第二列的值求平均值并根据第一列做分组统计,SQL写法是
select key,avg(value) from test group by key;

测试程序如下

package SparkUDAF;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

public class TestMain {
    public static void main(String[] args) {
        SparkConf conf =new SparkConf();
        conf.setMaster("local").setAppName("MyAvg");
        JavaSparkContext sc= new JavaSparkContext(conf);
        //得到SQLContext对象
        SQLContext sqlContext = new SQLContext(sc);

        //注册自定义函数
        sqlContext.udf().register("my_avg",new MyAvg());

        //读入数据
        JavaRDD<String> lines = sc.textFile("d:\\test.txt");
        //分词
        JavaRDD<Row> rows=lines.map(line-> RowFactory.create(line.split("\\^")));

        //定义schema的结构,a字段是字母,b字段是value
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("a",DataTypes.StringType,true));
        structFields.add(DataTypes.createStructField("b",DataTypes.StringType,true));
        StructType structType = DataTypes.createStructType(structFields);

        //创建DataFrame
        Dataset ds=sqlContext.createDataFrame(rows,structType);
        ds.registerTempTable("test");

        //执行查询
        sqlContext.sql("select a,my_avg(b) from test group by a").show();
        sc.stop();
    }
}

测试结果

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值