packageio.renren.utils.udf;importlombok.AllArgsConstructor;importlombok.Data;importlombok.NoArgsConstructor;importorg.apache.spark.sql.*;importorg.apache.spark.sql.expressions.Aggregator;importorg.apache.spark.sql.expressions.UserDefinedFunction;importjava.io.Serializable;importstaticorg.apache.spark.sql.Encoders.*;/**
* @program: renren-cloud
* @description:
* @author: yyyyjinying
* @create: 2023-06-20 14:11
**/publicclassJavaUserDefinedTypedAggregation{@Data@NoArgsConstructor@AllArgsConstructorpublicstaticclassEmployeeimplementsSerializable{privateString name;privatelong salary;}@Data@NoArgsConstructor@AllArgsConstructorpublicstaticclassAverageimplementsSerializable{privatelong sum;privatelong count;}publicstaticclassMyAverageextendsAggregator<Employee,Average,Double>{@OverridepublicAveragezero(){returnnewAverage(0L,0L);}@OverridepublicAveragereduce(Average buffer,Employee employee){long newSum = buffer.getSum()+ employee.getSalary();long newCount = buffer.getCount()+1;
buffer.setSum(newSum);
buffer.setCount(newCount);return buffer;}@OverridepublicAveragemerge(Average b1,Average b2){long mergedSum = b1.getSum()+ b2.getSum();long mergedCount = b1.getCount()+ b2.getCount();
b1.setSum(mergedSum);
b1.setCount(mergedCount);return b1;}@OverridepublicDoublefinish(Average reduction){return((double) reduction.getSum())/ reduction.getCount();}@OverridepublicEncoder<Average>bufferEncoder(){returnbean(Average.class);}@OverridepublicEncoder<Double>outputEncoder(){returnDOUBLE();}}publicstaticclassMyUnAverageextendsAggregator<Long,Average,Double>{// A zero value for this aggregation. Should satisfy the property that any b + zero = b@OverridepublicAveragezero(){returnnewAverage(0L,0L);}// Combine two values to produce a new value. For performance, the function may modify `buffer`// and return it instead of constructing a new object@OverridepublicAveragereduce(Average buffer,Long data){long newSum = buffer.getSum()+ data;long newCount = buffer.getCount()+1;
buffer.setSum(newSum);
buffer.setCount(newCount);return buffer;}// Merge two intermediate values@OverridepublicAveragemerge(Average b1,Average b2){long mergedSum = b1.getSum()+ b2.getSum();long mergedCount = b1.getCount()+ b2.getCount();
b1.setSum(mergedSum);
b1.setCount(mergedCount);return b1;}// Transform the output of the reduction@OverridepublicDoublefinish(Average reduction){return((double) reduction.getSum())/ reduction.getCount();}// Specifies the Encoder for the intermediate value type@OverridepublicEncoder<Average>bufferEncoder(){returnEncoders.bean(Average.class);}// Specifies the Encoder for the final output value type@OverridepublicEncoder<Double>outputEncoder(){returnEncoders.DOUBLE();}}publicstaticSparkSessiongetSpark(){returnSparkSession.builder().appName("spark udaf example").master("local[*]").config("dfs.client.use.datanode.hostname",true).getOrCreate();}publicstaticvoidaverageTypeUdf(){SparkSession spark =getSpark();// $example on:typed_custom_aggregation$Encoder<Employee> employeeEncoder =bean(Employee.class);String path ="renren-admin\\renren-admin-server\\src\\main\\resources\\employees.json";Dataset<Employee> ds = spark.read().json(path).as(employeeEncoder);
ds.show();// +-------+------+// | name|salary|// +-------+------+// |Michael| 3000|// | Andy| 4500|// | Justin| 3500|// | Berta| 4000|// +-------+------+MyAverage myAverage =newMyAverage();// Convert the function to a `TypedColumn` and give it a nameTypedColumn<Employee,Double> averageSalary = myAverage.toColumn().name("average_salary");Dataset<Double> result = ds.select(averageSalary);
result.show();// +--------------+// |average_salary|// +--------------+// | 3750.0|// +--------------+
spark.stop();}publicstaticvoidaverageUnTypeUdf(){SparkSession spark =getSpark();String path ="renren-admin\\renren-admin-server\\src\\main\\resources\\employees.json";Dataset<Row> df = spark.read().json(path);
df.createOrReplaceTempView("employees");
df.show();
spark.udf().register("myAverage", functions.udaf(newMyUnAverage(),LONG()));
spark.sql("SELECT myAverage(salary) as average_salary FROM employees").show();}publicstaticvoidmain(String[] args){// 强类型聚合平均值// averageTypeUdf();// 不安全类型聚合平均值averageUnTypeUdf();}}