背景需求
中位数(Median)又称中值,统计学中的专有名词,是按顺序排列的一组数据中居于中间位置的数,代表一个样本、种群或概率分布中的一个数值,其可将数值集合划分为相等的上下两部分。对于有限的数集,可以通过把所有观察值高低排序后找出正中间的一个作为中位数。如果观察值有偶数个,通常取最中间的两个数值的平均数作为中位数。
准备1~7共7个乱序数字

奇数个数字排序后为【1,2,3,4,5,6,7】,取正中间的一个数字4,则中位数是4

再增加一个数字8,依然是乱序

偶数个数字排序后为【1,2,3,4,5,6,7,8】,取中间两个数字【4,5】,二者之和除以2,得中位数为4.5

代码实现
package udf;
import com.google.common.collect.Lists;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.SettableDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
/**
 * @ClassName: UDAFNumericMedian
 * @Description: Hive generic UDAF that computes the median of a DOUBLE column. Every non-null
 *     value of a group is buffered (partial results are lists of doubles); at the end the values
 *     are sorted and the middle value (odd count) or the mean of the two middle values (even
 *     count) is taken. The result is returned as an array holding that single value, or NULL
 *     when the group has no non-null input.
 * @Author: xuezhouyi
 * @Version: V1.0
 **/
public class UDAFNumericMedian extends AbstractGenericUDAFResolver {

    /**
     * Chooses the evaluator for the SQL argument types; only a single DOUBLE argument is accepted.
     *
     * @param parameters type info of the SQL call's arguments
     * @return the evaluator computing the median of DOUBLE values
     * @throws SemanticException a {@link UDFArgumentTypeException} if there is not exactly one
     *     argument, or the argument is not a primitive DOUBLE
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 1) {
            throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected.");
        }
        if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
            throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted.");
        }
        if (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()
                != PrimitiveObjectInspector.PrimitiveCategory.DOUBLE) {
            throw new UDFArgumentTypeException(0, "Only Double type arguments are accepted.");
        }
        return new DoubleEvaluator();
    }

    /** Evaluator that buffers every non-null DOUBLE of a group and takes the median at the end. */
    private static class DoubleEvaluator extends GenericUDAFEvaluator {
        /** Inspector for raw input values (PARTIAL1/COMPLETE) or partial-list elements (PARTIAL2/FINAL). */
        private PrimitiveObjectInspector inputOI;
        /** Inspector for incoming partial lists; only set in PARTIAL2/FINAL modes. */
        private ListObjectInspector mergeOI;
        /** Inspector for the standard copies held in the buffer and for the emitted values. */
        private SettableDoubleObjectInspector standardOI;

        /**
         * Sets up the inspectors for the given aggregation mode.
         *
         * @return an inspector for a list of standard double values; partial and final results
         *     share this shape
         * @throws HiveException if a merge-mode input is not a list
         */
        @Override
        public ObjectInspector init(GenericUDAFEvaluator.Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                // Raw column values arrive one per iterate() call.
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else if (parameters[0] instanceof ListObjectInspector) {
                // Partial lists produced by terminatePartial(); accept any list inspector, not only
                // StandardListObjectInspector.
                mergeOI = (ListObjectInspector) parameters[0];
                inputOI = (PrimitiveObjectInspector) mergeOI.getListElementObjectInspector();
            } else {
                // merge() cannot work without a list inspector: fail here rather than with an NPE later.
                throw new UDFArgumentTypeException(0, "A list of partial results is expected in mode "
                        + m + ", got " + parameters[0].getTypeName());
            }
            // The resolver only admits DOUBLE, so the standard inspector is a writable or Java double
            // inspector; both are settable, which lets terminate() build correctly typed result values.
            standardOI = (SettableDoubleObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(inputOI);
            return ObjectInspectorFactory.getStandardListObjectInspector(standardOI);
        }

        /** Aggregation buffer: standard copies of every non-null value seen so far. */
        private static class MyAggBuf implements AggregationBuffer {
            final List<Object> container = Lists.newArrayList();
        }

        /** Clears the buffered values so the buffer can be reused for another group. */
        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            ((MyAggBuf) agg).container.clear();
        }

        /** Creates an empty aggregation buffer. */
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new MyAggBuf();
        }

        /**
         * Map side: buffers one input value; NULL values are ignored.
         *
         * @param parameters the SQL argument values for one row (exactly one is expected)
         */
        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            if (parameters == null || parameters.length != 1) {
                return;
            }
            if (parameters[0] != null) {
                MyAggBuf myAgg = (MyAggBuf) agg;
                // Copy: Hive may reuse the incoming object for the next row.
                myAgg.container.add(ObjectInspectorUtils.copyToStandardObject(parameters[0], this.inputOI));
            }
        }

        /** Returns a snapshot list of the values buffered so far (end of map/combiner stage). */
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            MyAggBuf myAgg = (MyAggBuf) agg;
            return Lists.newArrayList(myAgg.container);
        }

        /** Combiner/reducer: appends all values of one partial list to the buffer. */
        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            if (partial == null) {
                return;
            }
            List<?> partialResult = mergeOI.getList(partial);
            if (partialResult == null) {
                return;
            }
            MyAggBuf myAgg = (MyAggBuf) agg;
            for (Object value : partialResult) {
                if (value != null) {
                    myAgg.container.add(ObjectInspectorUtils.copyToStandardObject(value, this.inputOI));
                }
            }
        }

        /**
         * Reducer: sorts the buffered values and computes the median.
         *
         * @return a one-element list holding the median, or null (SQL NULL) if no non-null value
         *     was aggregated
         */
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            List<Object> container = ((MyAggBuf) agg).container;
            if (container.isEmpty()) {
                // Median of an empty set is undefined; previously this threw IndexOutOfBoundsException.
                return null;
            }
            double[] values = new double[container.size()];
            for (int k = 0; k < values.length; k++) {
                values[k] = standardOI.get(container.get(k));
            }
            Arrays.sort(values);
            int mid = values.length / 2;
            double median = (values.length % 2 != 0)
                    ? values[mid]
                    : (values[mid] + values[mid - 1]) / 2;
            // Build the value through the output inspector so its runtime type matches what Hive expects.
            List<Object> result = new ArrayList<>(1);
            result.add(standardOI.create(median));
            return result;
        }
    }
}
分组测试

运行结果

本文介绍了在Hive中自定义UDAF函数来计算中位数的需求、实现过程及测试案例。通过中位数的概念,展示了对奇数和偶数个数值计算中位数的方法,并给出了具体的代码实现和运行结果。
799

被折叠的 条评论
为什么被折叠?



