Flink 自定义GroupConcat函数

函数实现逻辑

import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.dataview.ListView;
import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
import org.apache.flink.table.types.DataType;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
 * Aggregate function that concatenates the non-null string values of a group.
 *
 * <p>Supported call shapes (resolved through the {@code accumulate} overloads):
 * <ul>
 *   <li>{@code (value)} - join with the default separator {@code ","}</li>
 *   <li>{@code (value, deduplication)} - default separator, optionally skipping duplicates</li>
 *   <li>{@code (separator, value)} - custom separator</li>
 *   <li>{@code (separator, value, deduplication)} - custom separator, optionally skipping duplicates</li>
 * </ul>
 *
 * <p>The accumulator keeps only the collected values; the separator is stored in its own field and
 * applied when the result is produced. Keeping separators out of the value list is what makes
 * {@link #retract} and {@link #merge} produce well-formed output (no dangling or missing
 * separators) and keeps the duplicate check from matching a separator.
 */
public class GroupConcatFullUDF extends BuiltInAggregateFunction<String, GroupConcatFullUDF.GroupConcatAccumulator> {

    /** Separator used when the caller does not supply one (or supplies {@code null}). */
    private static final String DEFAULT_SEPARATOR = ",";

    /** The aggregate always yields a single string. */
    @Override
    public DataType getOutputDataType() {
        return DataTypes.STRING();
    }

    /**
     * Argument types of the full variant: {@code (separator STRING, value STRING, deduplication BOOLEAN)}.
     * The nested subclasses override this to expose the shorter call shapes.
     */
    @Override
    public List<DataType> getArgumentDataTypes() {
        List<DataType> types = new ArrayList<>(3);
        types.add(DataTypes.STRING());
        types.add(DataTypes.STRING());
        types.add(DataTypes.BOOLEAN());
        return types;
    }

    /** Describes {@link GroupConcatAccumulator}: two string list views plus the separator. */
    @Override
    public DataType getAccumulatorDataType() {
        return DataTypes.STRUCTURED(
                GroupConcatAccumulator.class,
                DataTypes.FIELD(
                        "list",
                        ListView.newListViewDataType(
                                DataTypes.STRING())),
                DataTypes.FIELD(
                        "retractList",
                        ListView.newListViewDataType(
                                DataTypes.STRING())),
                DataTypes.FIELD(
                        "separator",
                        DataTypes.STRING()));
    }

    /** Empties both lists and forgets the separator so the accumulator can be reused. */
    public void resetAccumulator(GroupConcatAccumulator acc) {
        acc.list.clear();
        acc.retractList.clear();
        acc.separator = null;
    }

    /** Creates an empty accumulator; the separator stays {@code null} until the first value arrives. */
    @Override
    public GroupConcatAccumulator createAccumulator() {
        final GroupConcatAccumulator acc = new GroupConcatAccumulator();
        acc.list = new ListView<>();
        acc.retractList = new ListView<>();
        return acc;
    }

    /**
     * Retracts one occurrence of {@code value}. If the value has not been accumulated yet, the
     * retraction is remembered in {@code retractList} so that {@link #merge} can cancel it later.
     *
     * <p>Retraction is not supported in combination with deduplication yet: a value that was
     * accumulated several times is stored once, so a single retraction removes it entirely.
     */
    public void retract(GroupConcatAccumulator acc, String value) throws Exception {
        if (value != null) {
            if (!acc.list.remove(value)) {
                acc.retractList.add(value);
            }
        }
    }

    /**
     * Merges the given accumulators into {@code acc}: all values are appended in iteration order,
     * all pending retractions are collected, and each retraction then cancels the first matching
     * value. Retractions without a matching value stay pending in {@code retractList}.
     */
    public void merge(
            GroupConcatAccumulator acc, Iterable<GroupConcatAccumulator> its)
            throws Exception {
        // Gather everything first so the (possibly state-backed) views are rewritten only once.
        List<String> values = new ArrayList<>();
        List<String> retractions = new ArrayList<>();
        for (String value : acc.list.get()) {
            values.add(value);
        }
        for (String value : acc.retractList.get()) {
            retractions.add(value);
        }
        for (GroupConcatAccumulator otherAcc : its) {
            for (String value : otherAcc.list.get()) {
                values.add(value);
            }
            for (String value : otherAcc.retractList.get()) {
                retractions.add(value);
            }
            // Adopt the separator of a merged accumulator if this one has not seen a value yet.
            if (acc.separator == null) {
                acc.separator = otherAcc.separator;
            }
        }

        // Cancel retractions against collected values; keep the ones that found no match.
        List<String> pendingRetractions = new ArrayList<>();
        for (String value : retractions) {
            if (!values.remove(value)) {
                pendingRetractions.add(value);
            }
        }

        acc.list.clear();
        acc.list.addAll(values);
        acc.retractList.clear();
        acc.retractList.addAll(pendingRetractions);
    }

    /** Accumulates {@code value} with the default separator and without deduplication. */
    public void accumulate(GroupConcatAccumulator acc, String value) throws Exception {
        accumulate(acc, DEFAULT_SEPARATOR, value, Boolean.FALSE);
    }

    /** Accumulates {@code value} with the default separator, skipping duplicates if requested. */
    public void accumulate(GroupConcatAccumulator acc, String value, boolean deduplication) throws Exception {
        accumulate(acc, DEFAULT_SEPARATOR, value, deduplication);
    }

    /** Accumulates {@code value} with a custom separator and without deduplication. */
    public void accumulate(GroupConcatAccumulator acc, String separator, String value) throws Exception {
        accumulate(acc, separator, value, Boolean.FALSE);
    }

    /**
     * Accumulates one value.
     *
     * @param acc accumulator of the current group
     * @param separator separator to join values with; the most recently supplied one is used
     * @param value value to collect; {@code null} values are ignored
     * @param deduplication when {@code true}, a value already present in the list is skipped
     */
    public void accumulate(GroupConcatAccumulator acc, String separator, String value, boolean deduplication) throws Exception {
        if (Objects.isNull(value)) {
            return;
        }
        // Linear scan over the collected values; the list holds values only, so a value equal to
        // the separator is never mistaken for a duplicate. A better duplicate check is a TODO.
        if (deduplication && acc.list.getList().contains(value)) {
            return;
        }
        acc.separator = separator;
        acc.list.add(value);
    }

    /**
     * Joins the collected values with the stored separator (falling back to the default one).
     *
     * @return the joined string, an empty string when nothing was collected, or {@code null} when
     *     the accumulator has no list
     */
    @Override
    public String getValue(GroupConcatAccumulator acc) {
        if (Objects.isNull(acc.list)) {
            return null;
        }
        String separator = acc.separator != null ? acc.separator : DEFAULT_SEPARATOR;
        return String.join(separator, acc.list.getList());
    }

    /** Accumulator: collected values, retractions that have not been matched yet, and the separator. */
    public static class GroupConcatAccumulator {
        /** Values collected so far, in arrival order (no separators). */
        public ListView<String> list;
        /** Retracted values for which no collected value was found yet. */
        public ListView<String> retractList;
        /** Separator supplied with the most recent value; {@code null} until a value arrives. */
        public String separator;

        public GroupConcatAccumulator() {
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            GroupConcatAccumulator that = (GroupConcatAccumulator) o;
            return Objects.equals(list, that.list)
                    && Objects.equals(retractList, that.retractList)
                    && Objects.equals(separator, that.separator);
        }

        @Override
        public int hashCode() {
            return Objects.hash(list, retractList, separator);
        }
    }

    /** Joins the values of a group with a comma; supports deduplication: {@code (value, deduplication)}. */
    public static class GroupConcatDedupUDF extends GroupConcatFullUDF {
        @Override
        public List<DataType> getArgumentDataTypes() {
            List<DataType> types = new ArrayList<>(2);
            types.add(DataTypes.STRING());
            types.add(DataTypes.BOOLEAN());
            return types;
        }
    }

    /** Joins the values of a group with the given separator, no deduplication: {@code (separator, value)}. */
    public static class GroupConcatSepUDF extends GroupConcatFullUDF {
        @Override
        public List<DataType> getArgumentDataTypes() {
            List<DataType> types = new ArrayList<>(2);
            types.add(DataTypes.STRING());
            types.add(DataTypes.STRING());
            return types;
        }
    }

    /** Joins the values of a group with a comma, no deduplication: {@code (value)}. */
    public static class GroupConcatUDF extends GroupConcatFullUDF {
        @Override
        public List<DataType> getArgumentDataTypes() {
            return Collections.singletonList(DataTypes.STRING());
        }
    }
}

示例一:

CREATE TABLE print_table (
   `supplier_id` STRING,
   `total` STRING
) WITH (
 'connector' = 'print'
);

insert into print_table
SELECT supplier_id, GroupConcatUDF(product_id) AS total
FROM (VALUES
    ('supplier1', 'product1', 4),
    ('supplier1', 'product1', 3),
    ('supplier2', 'product3', 3),
    ('supplier2', 'product4', 4))
AS Products(supplier_id, product_id, rating)
GROUP BY supplier_id;

结果

2> +I[supplier1, product1]
12> +I[supplier2, product3]
2> -U[supplier1, product1]
2> +U[supplier1, product1,product1]
12> -U[supplier2, product3]
12> +U[supplier2, product3,product4]

示例二:

CREATE TABLE print_table (
   `supplier_id` STRING,
   `total` STRING
) WITH (
 'connector' = 'print'
);

insert into print_table
SELECT supplier_id, GroupConcatSepUDF('&',product_id) AS total
FROM (VALUES
    ('supplier1', 'product1', 4),
    ('supplier1', 'product1', 3),
    ('supplier2', 'product3', 3),
    ('supplier2', 'product4', 4))
AS Products(supplier_id, product_id, rating)
GROUP BY supplier_id;

结果

2> +I[supplier1, product1]
12> +I[supplier2, product3]
2> -U[supplier1, product1]
2> +U[supplier1, product1&product1]
12> -U[supplier2, product3]
12> +U[supplier2, product3&product4]

示例三:

CREATE TABLE print_table (
   `supplier_id` STRING,
   `total` STRING
) WITH (
 'connector' = 'print'
);

insert into print_table
SELECT supplier_id, GroupConcatDedupUDF(product_id,true) AS total
FROM (VALUES
    ('supplier1', 'product1', 4),
    ('supplier1', 'product1', 3),
    ('supplier2', 'product3', 3),
    ('supplier2', 'product4', 4))
AS Products(supplier_id, product_id, rating)
GROUP BY supplier_id;

结果

12> +I[supplier2, product3]
2> +I[supplier1, product1]
2> -U[supplier1, product1]
2> +U[supplier1, product1]
12> -U[supplier2, product3]
12> +U[supplier2, product3,product4]

示例四:

CREATE TABLE print_table (
   `supplier_id` STRING,
   `total` STRING
) WITH (
 'connector' = 'print'
);

insert into print_table
SELECT supplier_id, GroupConcatFullUDF('&',product_id,true) AS total
FROM (VALUES
    ('supplier1', 'product1', 4),
    ('supplier1', 'product1', 3),
    ('supplier2', 'product3', 3),
    ('supplier2', 'product4', 4))
AS Products(supplier_id, product_id, rating)
GROUP BY supplier_id;

结果

12> +I[supplier2, product3]
2> +I[supplier1, product1]
2> -U[supplier1, product1]
2> +U[supplier1, product1]
12> -U[supplier2, product3]
12> +U[supplier2, product3&product4]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值