目录
Spark UDF/UDAF
1.SparkUDF
(1)测试数据: emp.csv文件
(2)定义case class
case class Emp(empno:Int,ename:String,job:String,mgr:String,hiredate:String,sal:Int,comm:String,deptno:Int)
(3)导入emp.csv的文件
val lineRDD = sc.textFile("/emp.csv").map(_.split(","))
(4)生成DataFrame
val allEmp = lineRDD.map(x=>Emp(x(0).toInt,x(1),x(2),x(3),x(4),x(5).toInt,x(6),x(7).toInt))
val empDF = allEmp.toDF
(5)注册成一个临时视图
empDF.createOrReplaceTempView("emp")
(6)自定义一个函数,拼加字符串
spark.sqlContext.udf.register("concatstr",(s1:String,s2:String)=>s1+"***"+s2)
(7)调用自定义函数,将ename和job这两个字段拼接在一起
spark.sql("select concatstr(ename,job) from emp").show
测试结果
2.Spark UDAF
UDAF就是用户自定义聚合函数,比如平均值,最大最小值,累加,拼接等。
这里以求平均数为例,并用Java实现,以后应该会陆续更新一些自带函数的实现当练习。
(1)求平均数
package SparkUDAF;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
public class MyAvg extends UserDefinedAggregateFunction {
@Override
public StructType inputSchema() {
//输入数据的类型,输入的是字符串
List<StructField> structFields = new ArrayList<>();
structFields.add(DataTypes.createStructField("InputData", DataTypes.StringType, true));
return DataTypes.createStructType(structFields);
}
@Override
public StructType bufferSchema() {
//聚合操作时,所处理的数据的数据类型,在这个例子里求平均数,要先求和(Sum),然后除以个数(Amount),所以这里需要处理两个字段
//注意因为用了ArrayList,所以是有序的
List<StructField> structFields = new ArrayList<>();
structFields.add(DataTypes.createStructField("Amount", DataTypes.IntegerType, true));
structFields.add(DataTypes.createStructField("Sum", DataTypes.IntegerType, true));
return DataTypes.createStructType(structFields);
}
@Override
public DataType dataType() {
//UDAF计算后的返回值类型
return DataTypes.IntegerType;
}
@Override
public boolean deterministic() {
//判断输入和输出的类型是否一致,如果返回的是true则表示一致,false表示不一致,自行设置
return false;
}
@Override
public void initialize(MutableAggregationBuffer buffer) {
/*
对辅助字段进行初始化,就是上面定义的field1和field2
第一个辅助字段的下标为0,初始值为0
第二个辅助字段的下标为1,初始值为0
*/
buffer.update(0, 0);
buffer.update(1, 0);
}
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
/*
update可以认为是在每一个节点上都会对数据执行的操作,UDAF函数执行的时候,数据会被分发到每一个节点上,就是每一个分区
buffer.getInt(0)获取的是上一次聚合后的值,input就是当前获取的数据
*/
//修改辅助字段的值,buffer.getInt(x)获取的是上一次聚合后的值,x表示
buffer.update(0, buffer.getInt(0) + 1); //表示某个数字的个数
buffer.update(1, buffer.getInt(1) + Integer.parseInt(input.getString(0))); //表示某个数字的总和
}
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
/*
merge:对每个分区的结果进行合并,每个分布式的节点上做完update之后就要做一个全局合并的操作
合并每一个update操作的结果,将各个节点上的数据合并起来
buffer1.getInt(0) : 上一次聚合后的值
buffer2.getInt(0) : 这次计算传入进来的update的结果
*/
//对第一个字段Amount进行求和,求出总个数
buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
//对第二个字段Sum进行求和,求出总和
buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1));
}
@Override
public Object evaluate(Row buffer) {
//表示最终计算的结果,第二个参数表示和值,第一个参数表示个数
return buffer.getInt(1) / buffer.getInt(0);
}
}
测试数据
数据形式key^value,^是分隔符
a^4
a^6
b^2
b^4
b^6
对第二列的值求平均值并根据第一列做分组统计,SQL写法是
select key,avg(value) from test group by key;
测试程序如下
package SparkUDAF;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
public class TestMain {
public static void main(String[] args) {
SparkConf conf =new SparkConf();
conf.setMaster("local").setAppName("MyAvg");
JavaSparkContext sc= new JavaSparkContext(conf);
//得到SQLContext对象
SQLContext sqlContext = new SQLContext(sc);
//注册自定义函数
sqlContext.udf().register("my_avg",new MyAvg());
//读入数据
JavaRDD<String> lines = sc.textFile("d:\\test.txt");
//分词
JavaRDD<Row> rows=lines.map(line-> RowFactory.create(line.split("\\^")));
//定义schema的结构,a字段是字母,b字段是value
List<StructField> structFields = new ArrayList<>();
structFields.add(DataTypes.createStructField("a",DataTypes.StringType,true));
structFields.add(DataTypes.createStructField("b",DataTypes.StringType,true));
StructType structType = DataTypes.createStructType(structFields);
//创建DataFrame
Dataset ds=sqlContext.createDataFrame(rows,structType);
ds.registerTempTable("test");
//执行查询
sqlContext.sql("select a,my_avg(b) from test group by a").show();
sc.stop();
}
}
测试结果