Hive自定义聚合函数UDAF(计算中位数)

本文介绍了在Hive中自定义UDAF函数来计算中位数的需求、实现过程及测试案例。通过中位数的概念,展示了对奇数和偶数个数值计算中位数的方法,并给出了具体的代码实现和运行结果。

背景需求

中位数(Median)又称中值,统计学中的专有名词,是按顺序排列的一组数据中居于中间位置的数,代表一个样本、种群或概率分布中的一个数值,其可将数值集合划分为相等的上下两部分。对于有限的数集,可以通过把所有观察值高低排序后找出正中间的一个作为中位数。如果观察值有偶数个,通常取最中间的两个数值的平均数作为中位数。

准备1~7个乱序数字

奇数个数字经过排序【1,2,3,4,5,6,7】取出中间一个数字是4,则中位数是4

再增加一个数字8,依然是乱序

偶数个数字经过排序【1,2,3,4,5,6,7,8】取出中间两个数字【4,5】之和除以2则中位数是4.5

 

代码实现

package udf;

import com.google.common.collect.Lists;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.DoubleWritable;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

/**
 * @ClassName: UDAFNumericMedian
 * @Description: 
 * @Author: xuezhouyi
 * @Version: V1.0
 **/
public class UDAFNumericMedian extends AbstractGenericUDAFResolver {

	/* 该方法会根据sql传入的参数数据格式指定调用哪个Evaluator进行处理 */
	@Override
	public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
		if (parameters.length != 1) {
			throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected.");
		}
		if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
			throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted.");
		}
		if (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.DOUBLE) {
			throw new UDFArgumentTypeException(0, "Only Double type arguments are accepted.");
		}
		return new DoubleEvaluator();
	}

	private static class DoubleEvaluator extends GenericUDAFEvaluator {
		private PrimitiveObjectInspector inputOI;
		private StandardListObjectInspector mergeOI;

		/* 确定各个阶段输入输出参数的数据格式ObjectInspectors */
		@Override
		public ObjectInspector init(GenericUDAFEvaluator.Mode m, ObjectInspector[] parameters) throws HiveException {
			super.init(m, parameters);
			if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
				inputOI = (PrimitiveObjectInspector) parameters[0];
				return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(inputOI));
			} else {
				if (parameters[0] instanceof StandardListObjectInspector) {
					mergeOI = (StandardListObjectInspector) parameters[0];
					inputOI = (PrimitiveObjectInspector) mergeOI.getListElementObjectInspector();
					return ObjectInspectorUtils.getStandardObjectInspector(mergeOI);
				} else {
					inputOI = (PrimitiveObjectInspector) parameters[0];
					return ObjectInspectorFactory.getStandardListObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(inputOI));
				}
			}
		}

		/* 保存数据聚集结果的类 */
		private static class MyAggBuf implements AggregationBuffer {
			List<Object> container = Lists.newArrayList();
		}

		/* 重置聚集结果 */
		@Override
		public void reset(AggregationBuffer agg) throws HiveException {
			((MyAggBuf) agg).container.clear();
		}

		/* 获取数据集结果类 */
		@Override
		public AggregationBuffer getNewAggregationBuffer() throws HiveException {
			MyAggBuf myAgg = new MyAggBuf();
			return myAgg;
		}

		/* map阶段,迭代处理输入sql传过来的列数据 */
		@Override
		public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
			if (parameters == null || parameters.length != 1) {
				return;
			}
			if (parameters[0] != null) {
				MyAggBuf myAgg = (MyAggBuf) agg;
				myAgg.container.add(ObjectInspectorUtils.copyToStandardObject(parameters[0], this.inputOI));
			}
		}

		/* map与combiner结束返回结果,得到部分数据聚集结果 */
		@Override
		public Object terminatePartial(AggregationBuffer agg) throws HiveException {
			MyAggBuf myAgg = (MyAggBuf) agg;
			List<Object> list = Lists.newArrayList(myAgg.container);
			return list;
		}

		/* combiner合并map返回的结果,还有reducer合并mapper或combiner返回的结果 */
		@Override
		public void merge(AggregationBuffer agg, Object partial) throws HiveException {
			if (partial == null) {
				return;
			}
			MyAggBuf myAgg = (MyAggBuf) agg;
			List<Object> partialResult = (List<Object>) mergeOI.getList(partial);
			for (Object ob : partialResult) {
				myAgg.container.add(ObjectInspectorUtils.copyToStandardObject(ob, this.inputOI));
			}
		}

		/* reducer阶段,输出最终结果 */
		@Override
		public Object terminate(AggregationBuffer agg) throws HiveException {
			/* 排序 */
			List<Object> container = ((MyAggBuf) agg).container;
			LinkedList<Double> doubles = new LinkedList<>();
			for (Object o : container) {
				doubles.add(((DoubleWritable) o).get());
			}
			Collections.sort(doubles);

			/* 计算中位数 */
			DoubleWritable median;
			int size = doubles.size();
			if (size % 2 != 0) {
				int i = size / 2;
				median = new DoubleWritable(doubles.get(i));
			} else {
				int i = size / 2;
				int j = size / 2 - 1;
				median = new DoubleWritable((doubles.get(i) + doubles.get(j)) / 2);
			}

			/* 构建对象并返回 */
			ArrayList<Object> objects = new ArrayList<>();
			objects.add(median);
			return objects;
		}
	}
}

 

分组测试

 

运行结果

Hive 自定义函数主要分为 UDF(普通函数,一进一出)、UDAF聚合函数,多进一出)和 UDTF(表生成函数,一进多出),其应用场景如下: ### 数据格式转换 当需要将数据从一种格式转换为另一种格式,而 Hive 内置函数无法满足时,可使用自定义 UDF。例如,将日期字符串转换为特定格式,或者对复杂的 JSON 字符串进行解析。 ```python # 示例:自定义 UDF 实现日期格式转换 @udf(returnType=StringType()) def convert_date(date_str): try: # 假设输入日期格式为 'yyyy-MM-dd',转换为 'dd/MM/yyyy' return datetime.strptime(date_str, '%Y-%m-%d').strftime('%d/%m/%Y') except ValueError: return None ``` ### 复杂业务逻辑处理 对于复杂的业务规则和计算逻辑,内置函数难以实现,此时可以编写自定义函数。比如根据用户的行为数据计算用户的活跃度得分,或者根据商品的销售数据计算销售趋势。 ```python # 示例:自定义 UDF 计算用户活跃度得分 @udf(returnType=IntegerType()) def calculate_activity_score(login_count, purchase_count): # 简单示例,根据登录次数和购买次数计算活跃度得分 return login_count * 2 + purchase_count * 3 ``` ### 数据清洗和预处理 在数据进入分析流程之前,需要进行清洗和预处理,去除无效数据、处理缺失值等。自定义 UDF 可以根据具体的数据情况实现定制化的数据清洗逻辑。 ```python # 示例:自定义 UDF 处理缺失值 @udf(returnType=StringType()) def handle_missing_value(value): if value is None or value == '': return 'Unknown' return value ``` ### 自定义聚合操作 当 Hive 内置的聚合函数(如 SUM、COUNT、AVG 等)无法满足特定的聚合时,可以使用自定义 UDAF。例如,计算加权平均值、中位数等。 ```python # 示例:自定义 UDAF 计算加权平均值 class WeightedAverageUDAF(GenericUDAFResolver): def getEvaluator(self, _): return WeightedAverageEvaluator() class WeightedAverageEvaluator(UDAFEvaluator): def init(self): self.sum_weighted_values = 0 self.sum_weights = 0 def iterate(self, value, weight): if value is not None and weight is not None: self.sum_weighted_values += value * weight self.sum_weights += weight return True def terminatePartial(self): return (self.sum_weighted_values, self.sum_weights) def merge(self, partial): if partial is not None: self.sum_weighted_values += partial[0] self.sum_weights += partial[1] return True def terminate(self): if self.sum_weights == 0: return None return self.sum_weighted_values / self.sum_weights ``` ### 数据展开和拆分 对于包含数组或复杂数据结构的字段,需要将其展开为多行记录,或者进行拆分处理,可使用自定义 UDTF。例如,将包含多个标签的字符串拆分为多个单独的标签记录。 ```python # 示例:自定义 UDTF 拆分字符串 class SplitTagsUDTF(GenericUDTF): def initialize(self, arg_types): return [StringType()] def process(self, tags_str): if tags_str is not None: tags = tags_str.split(',') for tag in tags: self.forward([tag]) def close(self): pass ```
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值