Hive UDAF: computing user trajectories
Requirement: we need to derive each user's movement trajectory. For example:
uid | location_id | time |
---|---|---|
001 | 1 | 2021-10-01 12:21:21 |
001 | 2 | 2021-10-01 12:22:21 |
002 | 1 | 2021-10-01 09:23:21 |
001 | 1 | 2021-10-01 11:24:21 |
001 | 1 | 2021-10-01 12:25:21 |
001 | 2 | 2021-10-01 12:26:21 |
001 | 2 | 2021-10-01 01:26:21 |
003 | 2 | 2021-10-01 12:26:21 |
003 | 2 | 2021-10-01 12:27:21 |
Result
st | location_id | uid | et |
---|---|---|---|
2021-10-01 01:26:21 | 2 | 001 | 2021-10-01 01:26:21 |
2021-10-01 11:24:21 | 1 | 001 | 2021-10-01 12:21:21 |
2021-10-01 12:22:21 | 2 | 001 | 2021-10-01 12:22:21 |
2021-10-01 12:25:21 | 1 | 001 | 2021-10-01 12:25:21 |
2021-10-01 12:26:21 | 2 | 001 | 2021-10-01 12:26:21 |
2021-10-01 09:23:21 | 1 | 002 | 2021-10-01 09:23:21 |
2021-10-01 12:26:21 | 2 | 003 | 2021-10-01 12:27:21 |
Analysis:
First group by uid, then sort each group by time and compare every row with the one before it. Whenever the location_id changes, the preceding run of rows is collapsed into a single record; st and et are the start and end times of that run.
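To make the collapsing step concrete, here is a minimal standalone Java sketch of it (the class name, row encoding, and single-uid framing are illustrative only, not part of the final UDAF):
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class TrackSketch {

    // Collapse consecutive rows that share a location id into "st|location_id|et".
    // Each row is "time|location_id", already sorted by time within one uid.
    static List<String> collapse(List<String> sortedRows) {
        List<String> runs = new ArrayList<>();
        String[] start = null; // row that opened the current run
        String[] prev = null;  // row seen in the previous iteration
        for (String row : sortedRows) {
            String[] cur = row.split("\\|"); // [0] = time, [1] = location id
            if (start == null) {
                start = cur;
            } else if (!cur[1].equals(prev[1])) {
                runs.add(start[0] + "|" + prev[1] + "|" + prev[0]); // close the run
                start = cur;
            }
            prev = cur;
        }
        if (start != null) {
            runs.add(start[0] + "|" + prev[1] + "|" + prev[0]); // close the last run
        }
        return runs;
    }

    public static void main(String[] args) {
        List<String> rows = Arrays.asList(
                "2021-10-01 11:24:21|1",
                "2021-10-01 12:21:21|1",
                "2021-10-01 12:22:21|2",
                "2021-10-01 12:25:21|1");
        // Prints three runs, matching the middle uid=001 rows of the result table above.
        System.out.println(collapse(rows));
    }
}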
Note:
The data structure the group operates on is a plain List, and all of a group's rows are gathered on the reduce side, so per-group data volume matters: a single group must not be too large. Trajectory data usually has only a modest number of rows per uid, so a UDAF is a workable approach here.
Code
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>udf</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.apache.hive</groupId>
            <artifactId>hive-exec</artifactId>
            <version>3.1.2</version>
            <scope>provided</scope>
        </dependency>
    </dependencies>
</project>
package com.yzz.udaf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
/**
 * @author yzz
 * @time 2021/10/16 12:19
 * @E-mail yzzstyle@163.com
 * @since 0.0.1
 */
@Description(name = "track", value = "_FUNC_(time, location_id, ...) - computes a movement trajectory by merging consecutive rows with the same location")
public class TrackUDAF extends AbstractGenericUDAFResolver {

    /**
     * Argument layout: time, location_id, extra columns to carry through.
     * All arguments are read as strings.
     *
     * @param parameters the argument types
     * @return the evaluator
     * @throws SemanticException if fewer than two arguments are given
     */
    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length < 2) {
            throw new SemanticException("At least 2 arguments are required: argument 1 is the time (used for sorting and as st/et), argument 2 is the field to compare, e.g. the location id in a trajectory; any further arguments are carried through unchanged");
        }
        return new TrackUDAFEvaluator();
    }
    public static class TrackBuffer implements GenericUDAFEvaluator.AggregationBuffer {

        /** One element per input row, encoded as "time|location_id|extras...". */
        List<Object> data = new ArrayList<>();

        /**
         * Collapses the already sorted rows into runs: consecutive rows with the
         * same location id become a single "st|location_id|extras...|et" string.
         */
        List<Object> merge() {
            List<Object> newList = new ArrayList<>();
            StringBuilder first = null; // row that opened the current run
            String lastV1 = null;       // time of the previous row
            String lastV2 = null;       // location id of the previous row
            for (Object obj : data) {
                String s = obj.toString();
                String[] split = s.split("\\|");
                String v1 = split[0]; // time
                String v2 = split[1]; // location id
                if (null == first) {
                    first = new StringBuilder(s);
                } else if (!v2.equals(lastV2)) {
                    // The location changed: close the previous run with its end time.
                    first.append("|").append(lastV1);
                    newList.add(first.toString());
                    first = new StringBuilder(s);
                }
                lastV1 = v1;
                lastV2 = v2;
            }
            if (first != null) {
                // Close the last open run (also guards against an empty buffer).
                first.append("|").append(lastV1);
                newList.add(first.toString());
            }
            return newList;
        }
    }
    /**
     * Evaluator modes:
     *
     * PARTIAL1 - the map phase; calls iterate() and terminatePartial().
     *
     * PARTIAL2 - the combine phase; calls merge() and terminatePartial().
     *
     * FINAL - the reduce phase; calls merge() and terminate().
     *
     * COMPLETE - map-only, no reduce; calls iterate() and terminate().
     */
    public static class TrackUDAFEvaluator extends GenericUDAFEvaluator {

        /** ObjectInspectors for the raw input columns (PARTIAL1/COMPLETE). */
        private ObjectInspector[] MAP_OR_COMPLETE_OIS;

        /** ObjectInspector for the partial result list (PARTIAL2/FINAL). */
        private ListObjectInspector PARTIAL2_OR_FINAL_OIS;

        public TrackUDAFEvaluator() {
        }

        /**
         * In PARTIAL1 the raw inputs are: time, location_id, extra columns.
         *
         * @param m          the evaluator mode
         * @param parameters the input ObjectInspectors for this mode
         * @return the output ObjectInspector
         * @throws HiveException on initialization failure
         */
        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            super.init(m, parameters);
            // INPUT
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                // Map phase, or map-only execution: raw columns.
                MAP_OR_COMPLETE_OIS = parameters;
            } else {
                // PARTIAL2 and FINAL: a single list-of-strings partial result.
                PARTIAL2_OR_FINAL_OIS = (ListObjectInspector) parameters[0];
            }
            // OUTPUT: always a list of strings.
            return ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.javaStringObjectInspector);
        }
        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new TrackBuffer();
        }

        @Override
        public void reset(AggregationBuffer agg) throws HiveException {
            TrackBuffer trackBuffer = (TrackBuffer) agg;
            trackBuffer.data.clear();
        }

        @Override
        public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
            TrackBuffer trackBuffer = (TrackBuffer) agg;
            // Encode the row as "time|location_id|extras..."; the time must come
            // first so the plain string sort in terminate() orders rows by time.
            StringBuilder sb = new StringBuilder();
            for (int i = 0, len = parameters.length; i < len; i++) {
                String data = PrimitiveObjectInspectorUtils.getString(parameters[i], (PrimitiveObjectInspector) MAP_OR_COMPLETE_OIS[i]);
                sb.append(data);
                if (i != len - 1) {
                    sb.append("|");
                }
            }
            trackBuffer.data.add(sb.toString());
        }
        @Override
        public Object terminatePartial(AggregationBuffer agg) throws HiveException {
            TrackBuffer trackBuffer = (TrackBuffer) agg;
            return trackBuffer.data;
        }

        /**
         * All partial lists that share the same key end up here.
         *
         * @param agg     the aggregation buffer
         * @param partial a list produced by terminatePartial()
         * @throws HiveException on failure
         */
        @Override
        public void merge(AggregationBuffer agg, Object partial) throws HiveException {
            TrackBuffer trackBuffer = (TrackBuffer) agg;
            // Copy each element to a plain String before storing it: Hive may
            // reuse the underlying (lazy) objects between calls.
            ObjectInspector elementOI = PARTIAL2_OR_FINAL_OIS.getListElementObjectInspector();
            for (Object element : PARTIAL2_OR_FINAL_OIS.getList(partial)) {
                trackBuffer.data.add(ObjectInspectorUtils.copyToStandardJavaObject(element, elementOI));
            }
        }
        @Override
        public Object terminate(AggregationBuffer agg) throws HiveException {
            // Reduce side: sort the group's rows by their string form, which is
            // a sort by time because the time is the first encoded field.
            TrackBuffer trackBuffer = (TrackBuffer) agg;
            trackBuffer.data.sort(Comparator.comparing(Object::toString));
            return trackBuffer.merge();
        }
    }
}
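As a quick local sanity check, the buffer's merge logic can be driven directly without a Hive runtime. This demo class is illustrative and not part of the original post's code:
package com.yzz.udaf;

import java.util.Arrays;
import java.util.Comparator;

/** Illustrative local check of TrackBuffer.merge(); not part of the UDAF itself. */
public class TrackBufferDemo {
    public static void main(String[] args) {
        TrackUDAF.TrackBuffer buffer = new TrackUDAF.TrackBuffer();
        // Rows encoded exactly as iterate() builds them: "time|location_id|uid".
        buffer.data.addAll(Arrays.asList(
                "2021-10-01 12:21:21|1|001",
                "2021-10-01 11:24:21|1|001",
                "2021-10-01 12:22:21|2|001"));
        // Replay what terminate() does: sort by string form, then collapse runs.
        buffer.data.sort(Comparator.comparing(Object::toString));
        // Prints: [2021-10-01 11:24:21|1|001|2021-10-01 12:21:21,
        //          2021-10-01 12:22:21|2|001|2021-10-01 12:22:21]
        System.out.println(buffer.merge());
    }
}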
Package, upload, and register in Hive
- Build the jar (for example with mvn package)
- Upload the jar to HDFS
- Register it in Hive: add jar hdfs://nameservice/xxx
- Create a temporary function: create temporary function track as 'com.yzz.udaf.TrackUDAF';
- Run it: select track(data_time,location_id,tag) from test.dw group by tag; each group returns an array of pipe-delimited strings, which can be unpacked as sketched below
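The function returns one array of "st|location_id|...|et" strings per group. One possible way to unpack it into columns with lateral view explode and split (the table and column names follow the example query above and are assumptions):
select split(seg, '\\|')[0] as st,
       split(seg, '\\|')[1] as location_id,
       split(seg, '\\|')[2] as tag,
       split(seg, '\\|')[3] as et
from (
    select track(data_time, location_id, tag) as segs
    from test.dw
    group by tag
) t
lateral view explode(t.segs) e as seg;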