2.3 SQL语法的用户自定义函数
2.3.1 UDF
1)UDF:一行进入,一行出
2)代码实现
package com.atguigu.sparksql;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;
import static org.apache.spark.sql.functions.udf;
public class Test04_UDF {
    public static void main(String[] args) {
        // 1. Build the Spark configuration (local mode, all available cores).
        SparkConf sparkConf = new SparkConf().setAppName("sparksql").setMaster("local[*]");

        // 2. Obtain the SparkSession entry point.
        SparkSession sparkSession = SparkSession.builder().config(sparkConf).getOrCreate();

        // 3. Load the JSON data and expose it as a temporary SQL view.
        Dataset<Row> userDF = sparkSession.read().json("input/user.json");
        userDF.createOrReplaceTempView("user");

        // Define a one-row-in / one-row-out UDF that appends a suffix to the name.
        // Requires: import static org.apache.spark.sql.functions.udf;
        UDF1<String, String> appendSuffix = new UDF1<String, String>() {
            @Override
            public String call(String s) throws Exception {
                return s + " 大侠";
            }
        };
        UserDefinedFunction addName = udf(appendSuffix, DataTypes.StringType);
        sparkSession.udf().register("addName", addName);

        sparkSession.sql("select addName(name) newName from user").show();

        // Equivalent registration written as a lambda expression.
        sparkSession.udf().register("addName1", (UDF1<String, String>) name -> name + " 大侠", DataTypes.StringType);

        // 4. Release the SparkSession.
        sparkSession.close();
    }
}
2.3.2 UDAF
1)UDAF:输入多行,返回一行。通常和groupBy一起使用,如果直接使用UDAF函数,默认将所有的数据合并在一起。
2)Spark3.x推荐使用extends Aggregator自定义UDAF,属于强类型的Dataset方式。
3)Spark2.x使用extends UserDefinedAggregateFunction,属于弱类型的DataFrame方式。
4)案例实操
需求:实现求平均年龄,自定义UDAF,MyAvg(age)
(1)自定义聚合函数实现-强类型
package com.atguigu.sparksql;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import java.io.Serializable;
import static org.apache.spark.sql.functions.udaf;
public class Test05_UDAF {
    public static void main(String[] args) {
        // 1. Build the Spark configuration (local mode, all available cores).
        SparkConf conf = new SparkConf().setAppName("sparksql").setMaster("local[*]");

        // 2. Obtain the SparkSession.
        SparkSession spark = SparkSession.builder().config(conf).getOrCreate();

        // 3. Register the view and the strongly-typed UDAF, then call it from SQL.
        spark.read().json("input/user.json").createOrReplaceTempView("user");

        // Registration requires: import static org.apache.spark.sql.functions.udaf;
        spark.udf().register("avgAge", udaf(new MyAvg(), Encoders.LONG()));
        spark.sql("select avgAge(age) newAge from user").show();

        // 4. Close the SparkSession.
        spark.close();
    }

    /**
     * Mutable aggregation buffer holding the running sum and row count.
     * Serializable so Spark can ship it between executors.
     */
    public static class Buffer implements Serializable {
        private Long sum;
        private Long count;

        public Buffer() {
        }

        public Buffer(Long sum, Long count) {
            this.sum = sum;
            this.count = count;
        }

        public Long getSum() {
            return sum;
        }

        public void setSum(Long sum) {
            this.sum = sum;
        }

        public Long getCount() {
            return count;
        }

        public void setCount(Long count) {
            this.count = count;
        }
    }

    /**
     * Strongly-typed average aggregator: IN = Long (age),
     * BUF = Buffer (sum, count), OUT = Double (average).
     */
    public static class MyAvg extends Aggregator<Long, Buffer, Double> {
        @Override
        public Buffer zero() {
            // Neutral element: zero sum, zero rows.
            return new Buffer(0L, 0L);
        }

        @Override
        public Buffer reduce(Buffer b, Long a) {
            // Fold one input value into the buffer.
            // NOTE(review): assumes `a` is non-null — a null age in the JSON
            // would NPE here; confirm the input data or add a null guard.
            b.setSum(b.getSum() + a);
            b.setCount(b.getCount() + 1);
            return b;
        }

        @Override
        public Buffer merge(Buffer b1, Buffer b2) {
            // Combine two partial buffers from different partitions.
            b1.setSum(b1.getSum() + b2.getSum());
            b1.setCount(b1.getCount() + b2.getCount());
            return b1;
        }

        @Override
        public Double finish(Buffer reduction) {
            // Fix: guard the empty group. The original computed 0.0 / 0 = NaN;
            // SQL avg() returns NULL for empty input, so return null instead.
            if (reduction.getCount() == 0) {
                return null;
            }
            return reduction.getSum().doubleValue() / reduction.getCount();
        }

        @Override
        public Encoder<Buffer> bufferEncoder() {
            // Kryo serialization for the intermediate buffer.
            return Encoders.kryo(Buffer.class);
        }

        @Override
        public Encoder<Double> outputEncoder() {
            return Encoders.DOUBLE();
        }
    }
}
2.3.3 UDTF(没有)
输入一行,返回多行(Hive)。
SparkSQL中没有UDTF,需要使用算子类型的flatMap先完成拆分。