函数实现逻辑
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.dataview.ListView;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
 * Aggregate function that concatenates the non-null string values of a group into a single
 * string, with an optional custom separator and optional de-duplication.
 *
 * <p>Call shapes, one per {@code accumulate} overload: {@code (value)},
 * {@code (value, deduplication)}, {@code (separator, value)} and
 * {@code (separator, value, deduplication)}. This class declares the three-argument shape; the
 * nested subclasses declare the remaining ones.
 *
 * <p>The accumulator keeps the collected values and the separator apart; the separator is only
 * applied when the result is rendered in {@link #getValue}. Keeping separators out of the value
 * list is what makes {@code retract}, {@code merge} and the de-duplication check operate on
 * values only.
 *
 * <p>NOTE(review): the accumulator carries a {@code separator} field in addition to the two list
 * views, so state written with a two-field accumulator layout is presumably not restorable —
 * confirm before upgrading a running job.
 */
public class GroupConcatFullUDF extends BuiltInAggregateFunction<String, GroupConcatFullUDF.GroupConcatAccumulator> {

    /** Separator used by the overloads that do not take an explicit one. */
    private static final String DEFAULT_SEPARATOR = ",";

    /** @return the result type, a string */
    @Override
    public DataType getOutputDataType() {
        return DataTypes.STRING();
    }

    /** @return the argument types of the full call shape: separator, value, de-duplication flag */
    @Override
    public List<DataType> getArgumentDataTypes() {
        List<DataType> types = new ArrayList<>(3);
        types.add(DataTypes.STRING());
        types.add(DataTypes.STRING());
        types.add(DataTypes.BOOLEAN());
        return types;
    }

    /** @return the structured type describing {@link GroupConcatAccumulator} */
    @Override
    public DataType getAccumulatorDataType() {
        return DataTypes.STRUCTURED(
                GroupConcatAccumulator.class,
                DataTypes.FIELD("list", ListView.newListViewDataType(DataTypes.STRING())),
                DataTypes.FIELD("retractList", ListView.newListViewDataType(DataTypes.STRING())),
                DataTypes.FIELD("separator", DataTypes.STRING()));
    }

    /**
     * Empties both list views and restores the default separator.
     *
     * @param acc the accumulator to reset
     */
    public void resetAccumulator(GroupConcatAccumulator acc) {
        acc.list.clear();
        acc.retractList.clear();
        acc.separator = DEFAULT_SEPARATOR;
    }

    /** @return a fresh accumulator with empty list views and the default separator */
    @Override
    public GroupConcatAccumulator createAccumulator() {
        final GroupConcatAccumulator acc = new GroupConcatAccumulator();
        acc.list = new ListView<>();
        acc.retractList = new ListView<>();
        acc.separator = DEFAULT_SEPARATOR;
        return acc;
    }

    /**
     * Retracts one previously accumulated value. If the value is not currently in the list, the
     * retraction is remembered in {@code retractList} so that {@link #merge} can cancel it later.
     *
     * <p>Retraction is not supported in the de-duplication scenario for now; to be improved later.
     * Only the single-value call shape has a {@code retract} counterpart.
     *
     * @param acc the accumulator to update
     * @param value the value to retract; {@code null} is ignored
     * @throws Exception if the underlying list view fails
     */
    public void retract(GroupConcatAccumulator acc, String value) throws Exception {
        if (value == null) {
            return;
        }
        if (!acc.list.remove(value)) {
            acc.retractList.add(value);
        }
    }

    /**
     * Merges the given accumulators into {@code acc}: values are appended in iteration order,
     * pending retractions are collected, and every retraction that finds a matching value
     * cancels its first occurrence. Unmatched retractions stay pending.
     *
     * @param acc the accumulator receiving the merged state
     * @param its the accumulators to merge into {@code acc}
     * @throws Exception if an underlying list view fails
     */
    public void merge(
            GroupConcatAccumulator acc, Iterable<GroupConcatAccumulator> its)
            throws Exception {
        // Read the target accumulator once instead of once per merged accumulator.
        final List<String> values = new ArrayList<>();
        final List<String> pendingRetractions = new ArrayList<>();
        acc.list.get().forEach(values::add);
        acc.retractList.get().forEach(pendingRetractions::add);

        boolean mergedAny = false;
        for (GroupConcatAccumulator otherAcc : its) {
            mergedAny = true;
            boolean otherHasValues = false;
            for (String value : otherAcc.list.get()) {
                values.add(value);
                otherHasValues = true;
            }
            // Only an accumulator that actually received values knows the caller's separator.
            if (otherHasValues && otherAcc.separator != null) {
                acc.separator = otherAcc.separator;
            }
            for (String retracted : otherAcc.retractList.get()) {
                pendingRetractions.add(retracted);
            }
        }
        if (!mergedAny) {
            return;
        }

        // Cancel each retraction against the first matching value; keep the unmatched ones.
        final List<String> unmatchedRetractions = new ArrayList<>();
        for (String retracted : pendingRetractions) {
            if (!values.remove(retracted)) {
                unmatchedRetractions.add(retracted);
            }
        }

        acc.list.clear();
        acc.list.addAll(values);
        acc.retractList.clear();
        acc.retractList.addAll(unmatchedRetractions);
    }

    /** Accumulates {@code value} with the default separator and without de-duplication. */
    public void accumulate(GroupConcatAccumulator acc, String value) throws Exception {
        accumulate(acc, DEFAULT_SEPARATOR, value, Boolean.FALSE);
    }

    /** Accumulates {@code value} with the default separator and the given de-duplication flag. */
    public void accumulate(GroupConcatAccumulator acc, String value, boolean deduplication) throws Exception {
        accumulate(acc, DEFAULT_SEPARATOR, value, deduplication);
    }

    /** Accumulates {@code value} with the given separator and without de-duplication. */
    public void accumulate(GroupConcatAccumulator acc, String separator, String value) throws Exception {
        accumulate(acc, separator, value, Boolean.FALSE);
    }

    /**
     * Adds one value to the accumulator.
     *
     * @param acc the accumulator to update
     * @param separator the separator to render between values; {@code null} falls back to the
     *     default separator
     * @param value the value to add; {@code null} is ignored
     * @param deduplication when {@code true}, a value already present in the list is skipped
     * @throws Exception if the underlying list view fails
     */
    public void accumulate(GroupConcatAccumulator acc, String separator, String value, boolean deduplication) throws Exception {
        if (value == null) {
            return;
        }
        // The duplicate check is a linear scan of the collected values; a better approach is
        // still to be found.
        if (deduplication && acc.list.getList().contains(value)) {
            return;
        }
        acc.separator = separator != null ? separator : DEFAULT_SEPARATOR;
        acc.list.add(value);
    }

    /**
     * Renders the accumulated values joined by the accumulator's separator.
     *
     * @param acc the accumulator to read
     * @return the joined string (empty when no values are present), or {@code null} when the
     *     accumulator has no list view
     */
    @Override
    public String getValue(GroupConcatAccumulator acc) {
        if (acc.list == null) {
            return null;
        }
        final String separator = acc.separator != null ? acc.separator : DEFAULT_SEPARATOR;
        return String.join(separator, acc.list.getList());
    }

    /** Accumulator: collected values, retractions that are still unmatched, and the separator. */
    public static class GroupConcatAccumulator {
        /** Collected values, without separators. */
        public ListView<String> list;
        /** Retracted values that had no matching entry in {@link #list} yet. */
        public ListView<String> retractList;
        /** Separator rendered between values by {@code getValue}. */
        public String separator;

        public GroupConcatAccumulator() {
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            GroupConcatAccumulator that = (GroupConcatAccumulator) o;
            return Objects.equals(list, that.list)
                    && Objects.equals(retractList, that.retractList)
                    && Objects.equals(separator, that.separator);
        }

        @Override
        public int hashCode() {
            return Objects.hash(list, retractList, separator);
        }
    }

    /** Joins the grouped values with a comma; supports de-duplication. */
    public static class GroupConcatDedupUDF extends GroupConcatFullUDF {
        /** @return the argument types: value, de-duplication flag */
        @Override
        public List<DataType> getArgumentDataTypes() {
            List<DataType> types = new ArrayList<>(2);
            types.add(DataTypes.STRING());
            types.add(DataTypes.BOOLEAN());
            return types;
        }
    }

    /** Joins the grouped values with the given separator; no de-duplication. */
    public static class GroupConcatSepUDF extends GroupConcatFullUDF {
        /** @return the argument types: separator, value */
        @Override
        public List<DataType> getArgumentDataTypes() {
            List<DataType> types = new ArrayList<>(2);
            types.add(DataTypes.STRING());
            types.add(DataTypes.STRING());
            return types;
        }
    }

    /** Joins the grouped values with a comma; no de-duplication. */
    public static class GroupConcatUDF extends GroupConcatFullUDF {
        /** @return the argument types: value only */
        @Override
        public List<DataType> getArgumentDataTypes() {
            return Collections.singletonList(DataTypes.STRING());
        }
    }
}
示例一:
CREATE TABLE print_table (
`supplier_id` STRING,
`total` STRING
) WITH (
'connector' = 'print'
);
insert into print_table
SELECT supplier_id, GroupConcatUDF(product_id) AS total
FROM (VALUES
('supplier1', 'product1', 4),
('supplier1', 'product1', 3),
('supplier2', 'product3', 3),
('supplier2', 'product4', 4))
AS Products(supplier_id, product_id, rating)
GROUP BY supplier_id;
结果
2> +I[supplier1, product1]
12> +I[supplier2, product3]
2> -U[supplier1, product1]
2> +U[supplier1, product1,product1]
12> -U[supplier2, product3]
12> +U[supplier2, product3,product4]
示例二:
CREATE TABLE print_table (
`supplier_id` STRING,
`total` STRING
) WITH (
'connector' = 'print'
);
insert into print_table
SELECT supplier_id, GroupConcatSepUDF('&',product_id) AS total
FROM (VALUES
('supplier1', 'product1', 4),
('supplier1', 'product1', 3),
('supplier2', 'product3', 3),
('supplier2', 'product4', 4))
AS Products(supplier_id, product_id, rating)
GROUP BY supplier_id;
结果
2> +I[supplier1, product1]
12> +I[supplier2, product3]
2> -U[supplier1, product1]
2> +U[supplier1, product1&product1]
12> -U[supplier2, product3]
12> +U[supplier2, product3&product4]
示例三:
CREATE TABLE print_table (
`supplier_id` STRING,
`total` STRING
) WITH (
'connector' = 'print'
);
insert into print_table
SELECT supplier_id, GroupConcatDedupUDF(product_id,true) AS total
FROM (VALUES
('supplier1', 'product1', 4),
('supplier1', 'product1', 3),
('supplier2', 'product3', 3),
('supplier2', 'product4', 4))
AS Products(supplier_id, product_id, rating)
GROUP BY supplier_id;
结果
12> +I[supplier2, product3]
2> +I[supplier1, product1]
2> -U[supplier1, product1]
2> +U[supplier1, product1]
12> -U[supplier2, product3]
12> +U[supplier2, product3,product4]
示例四:
CREATE TABLE print_table (
`supplier_id` STRING,
`total` STRING
) WITH (
'connector' = 'print'
);
insert into print_table
SELECT supplier_id, GroupConcatFullUDF('&',product_id,true) AS total
FROM (VALUES
('supplier1', 'product1', 4),
('supplier1', 'product1', 3),
('supplier2', 'product3', 3),
('supplier2', 'product4', 4))
AS Products(supplier_id, product_id, rating)
GROUP BY supplier_id;
结果
12> +I[supplier2, product3]
2> +I[supplier1, product1]
2> -U[supplier1, product1]
2> +U[supplier1, product1]
12> -U[supplier2, product3]
12> +U[supplier2, product3&product4]