需求：将 Hive 代码转成 Spark 代码时，两者标准差的实现方式不同——Hive 的 stddev 最后除以 n（总体标准差），而 Spark 的 stddev 除以 n-1（样本标准差），因此需要自定义标准差的 UDAF。以下是代码；有的异常值的情况没做处理，可以自行处理，因为本人在上游数据源已经做了一次处理，确保了数据的格式不会出现异常。
import org.apache.commons.lang.StringUtils;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
/**
 * Created by HGH on 2019-09-11.
 *
 * <p>Custom Spark SQL aggregate that computes the <em>population</em> standard
 * deviation (divides by {@code n}), matching what Hive's stddev does. Spark's
 * stddev divides by {@code n - 1}, hence this custom implementation.
 *
 * <p>The input column is a string holding a numeric value. Blank input
 * (null, empty or whitespace-only) is counted as {@code 0}, which is what the
 * original implementation intended.
 * NOTE(review): Hive itself skips NULL values in aggregates; if exact Hive
 * parity on NULL/blank rows is required, skip them in {@link #update} instead.
 *
 * <p>The aggregation buffer holds a running (count, mean, M2) triple, where M2 is
 * the sum of squared deviations from the running mean. It is updated per row with
 * Welford's algorithm and combined across partitions with the pairwise formula,
 * so the result is correct for any group size and any partitioning.
 */
public class MyStddevUDAF extends UserDefinedAggregateFunction {

    /** Buffer slot: number of values aggregated so far. */
    private static final int COUNT = 0;
    /** Buffer slot: running mean of the values. */
    private static final int MEAN = 1;
    /** Buffer slot: running sum of squared deviations from the mean. */
    private static final int M2 = 2;

    /**
     * Declares the single input column: a nullable string containing a number.
     *
     * @return the input schema
     */
    @Override
    public StructType inputSchema() {
        return DataTypes.createStructType(new StructField[] {
            DataTypes.createStructField("num", DataTypes.StringType, true)
        });
    }

    /**
     * @return the result type, {@code DoubleType}
     */
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    /**
     * Declares the intermediate buffer: count (long), mean (double), M2 (double).
     *
     * @return the buffer schema
     */
    @Override
    public StructType bufferSchema() {
        return DataTypes.createStructType(new StructField[] {
            DataTypes.createStructField("count", DataTypes.LongType, false),
            DataTypes.createStructField("mean", DataTypes.DoubleType, false),
            DataTypes.createStructField("m2", DataTypes.DoubleType, false)
        });
    }

    /**
     * Resets the buffer to the empty aggregate (no values seen).
     *
     * @param buffer the buffer to initialize
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(COUNT, 0L);
        buffer.update(MEAN, 0.0d);
        buffer.update(M2, 0.0d);
    }

    /**
     * Folds one input row into the buffer using Welford's online update.
     *
     * @param buffer the running aggregate
     * @param input  the input row; column 0 is the numeric string
     * @throws NumberFormatException if the value is non-blank and not a valid number
     */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        double value = parseValue(input.getString(0));

        long count = buffer.getLong(COUNT) + 1L;
        double oldMean = buffer.getDouble(MEAN);
        double delta = value - oldMean;
        double newMean = oldMean + delta / count;
        // delta uses the old mean, (value - newMean) the new one; their product
        // is the exact increment of the sum of squared deviations.
        double m2 = buffer.getDouble(M2) + delta * (value - newMean);

        buffer.update(COUNT, count);
        buffer.update(MEAN, newMean);
        buffer.update(M2, m2);
    }

    /**
     * Combines two partial aggregates into {@code buffer1}.
     *
     * @param buffer1 the aggregate that receives the merged result
     * @param buffer2 the other partial aggregate (read-only)
     */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        long rightCount = buffer2.getLong(COUNT);
        if (rightCount == 0L) {
            return; // nothing to merge in
        }

        long leftCount = buffer1.getLong(COUNT);
        if (leftCount == 0L) {
            // Left side is empty: the merged aggregate is just the right side.
            buffer1.update(COUNT, rightCount);
            buffer1.update(MEAN, buffer2.getDouble(MEAN));
            buffer1.update(M2, buffer2.getDouble(M2));
            return;
        }

        long total = leftCount + rightCount;
        double leftMean = buffer1.getDouble(MEAN);
        double delta = buffer2.getDouble(MEAN) - leftMean;
        // Weights are computed in double so the product of two counts cannot overflow.
        double rightWeight = (double) rightCount / (double) total;
        double mergedMean = leftMean + delta * rightWeight;
        double mergedM2 = buffer1.getDouble(M2) + buffer2.getDouble(M2)
                + delta * delta * (double) leftCount * rightWeight;

        buffer1.update(COUNT, total);
        buffer1.update(MEAN, mergedMean);
        buffer1.update(M2, mergedM2);
    }

    /**
     * Produces the population standard deviation, {@code sqrt(M2 / count)}.
     *
     * @param buffer the final aggregate
     * @return the standard deviation as a {@code Double}, or {@code null} when no
     *         rows were aggregated
     */
    @Override
    public Object evaluate(Row buffer) {
        long count = buffer.getLong(COUNT);
        if (count == 0L) {
            return null; // empty group: SQL NULL rather than NaN from 0/0
        }
        return Math.sqrt(buffer.getDouble(M2) / (double) count);
    }

    /**
     * @return {@code true}; the result depends only on the input values
     */
    @Override
    public boolean deterministic() {
        return true;
    }

    /**
     * Parses one raw input cell; blank (null, empty, whitespace-only) maps to 0.
     */
    private static double parseValue(String raw) {
        if (StringUtils.isBlank(raw)) {
            return 0.0d;
        }
        return Double.parseDouble(raw.trim());
    }
}