给DuckDB 2025年5月 c++插件模板添加聚合函数mpz_sum

为DuckDB插件添加mpz_sum聚合函数

编写聚合函数的实现文件mpz.cpp。文件开头需要包含下面列出的这一组头文件,否则会出现编译错误:variable 'duckdb::CreateAggregateFunctionInfo info' has initializer but incomplete type

// UDAF mpz_sum: sums arbitrary-precision integers, passed as decimal strings, using GMP
#pragma once
#include "duckdb.hpp"
#include "duckdb/function/function_set.hpp"
#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp"
#include "duckdb/function/scalar/nested_functions.hpp"
//#include "duckdb/core_functions/aggregate/nested_functions.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/pair.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/common/types/vector.hpp"
#include "gmp.h"

#include <map>
//namespace quackml {
// Aggregate state for mpz_sum: holds the running total.
struct MPZSumState {
    mpz_t sum;  // running total, stored as a GMP arbitrary-precision integer
};

// Lifecycle hooks for an mpz_sum state. This type is supplied as the
// operation class to duckdb::AggregateFunction::StateInitialize and
// duckdb::AggregateFunction::StateDestroy when the aggregate is built.
struct MPZSumFunction {
    // Always reports true.
    // NOTE(review): presumably consulted by DuckDB's aggregate templates to
    // decide whether NULL inputs are skipped - confirm against the DuckDB version in use.
    static bool IgnoreNull() {
        return true;
    }

    // Brings the GMP integer inside `state` into an initialized, usable condition.
    template <class STATE>
    static void Initialize(STATE &state) {
        mpz_init(state.sum);
    }

    // Releases the GMP integer inside `state`. The aggregate input data is not used.
    template <class STATE>
    static void Destroy(STATE &state, duckdb::AggregateInputData &aggr_input_data) {
        mpz_clear(state.sum);
    }
};

// Copies the characters of a DuckDB string value into a std::string after
// verifying that every character is an ASCII decimal digit.
//
// Parameters:
//   input - the DuckDB string value to read.
// Returns:
//   A std::string holding exactly the characters of `input`. An empty input
//   passes validation and yields an empty string.
// Throws:
//   std::runtime_error if any character is outside '0'..'9'; signs,
//   whitespace and separators are therefore rejected.
static std::string GetNumericString(const duckdb::string_t& input) {
    // Read through string_t's public accessors instead of reinterpreting the
    // object as the C-API duckdb_string_t: the accessors already hide the
    // inlined-vs-pointer storage distinction, so no layout assumptions or
    // magic length threshold are needed here.
    const char* data = input.GetData();
    const size_t length = input.GetSize();

    // Accept decimal digits only.
    for (size_t i = 0; i < length; i++) {
        if (data[i] < '0' || data[i] > '9') {
            throw std::runtime_error("Invalid character in number");
        }
    }

    return std::string(data, length);
}

// Update callback for mpz_sum: adds each non-NULL input string, parsed as a
// base-10 integer, to the state addressed by the same row.
//
// Parameters:
//   inputs       - argument vectors; only inputs[0] is read, as string values.
//   input_count  - number of argument vectors (not used).
//   state_vector - vector of MPZSumState pointers, one per row.
//   count        - number of rows to process.
// Throws:
//   std::runtime_error, with the message prefixed by "Error processing number: ",
//   when a value is not a digit-only string or GMP cannot parse it.
static void MPZSumUpdate(duckdb::Vector inputs[], duckdb::AggregateInputData &, idx_t input_count, duckdb::Vector &state_vector, idx_t count) {
    duckdb::UnifiedVectorFormat state_format;
    state_vector.ToUnifiedFormat(count, state_format);
    duckdb::UnifiedVectorFormat value_format;
    inputs[0].ToUnifiedFormat(count, value_format);

    auto state_ptrs = reinterpret_cast<MPZSumState **>(state_format.data);
    auto values = duckdb::UnifiedVectorFormat::GetData<duckdb::string_t>(value_format);

    for (idx_t row = 0; row < count; row++) {
        const auto value_idx = value_format.sel->get_index(row);
        if (!value_format.validity.RowIsValid(value_idx)) {
            continue;  // invalid (NULL) rows contribute nothing to the sum
        }

        MPZSumState &target = *state_ptrs[state_format.sel->get_index(row)];

        try {
            // Validate and copy the digits before handing them to GMP.
            const std::string digits = GetNumericString(values[value_idx]);

            mpz_t parsed;
            mpz_init(parsed);
            if (mpz_set_str(parsed, digits.c_str(), 10) != 0) {
                // Release the temporary before reporting the failure.
                mpz_clear(parsed);
                throw std::runtime_error("Failed to convert string to GMP number");
            }
            mpz_add(target.sum, target.sum, parsed);
            mpz_clear(parsed);
        } catch (const std::exception &e) {
            // Re-raise with a common prefix so callers see which stage failed.
            throw std::runtime_error("Error processing number: " + std::string(e.what()));
        }
    }
}

// Finalize callback for mpz_sum: renders each state's total as base-10 text
// and stores it in the result vector.
//
// Parameters:
//   state_vector - vector of MPZSumState pointers, one per row.
//   result       - output vector of string values; row i is written at
//                  position i + offset.
//   count        - number of states to finalize.
//   offset       - index of the first result slot to write.
static void MPZSumFinalize(duckdb::Vector &state_vector, duckdb::AggregateInputData &, duckdb::Vector &result, idx_t count, idx_t offset) {
    duckdb::UnifiedVectorFormat sdata;
    state_vector.ToUnifiedFormat(count, sdata);
    auto states = (MPZSumState **)sdata.data;

    for (idx_t i = 0; i < count; i++) {
        const auto rid = i + offset;
        auto &state = *states[sdata.sel->get_index(i)];

        // Convert into a buffer we own instead of letting GMP allocate one:
        // a GMP-allocated string must be released through GMP's own free
        // function, not ::free, and would need manual cleanup on every exit
        // path. GMP documents mpz_sizeinbase(op, base) + 2 as sufficient room
        // for the digits, an optional minus sign and the terminating NUL.
        std::string digits(mpz_sizeinbase(state.sum, 10) + 2, '\0');
        mpz_get_str(&digits[0], 10, state.sum);

        // c_str() stops at the NUL written by GMP, so unused padding is dropped.
        // AddString copies the text into storage owned by the result vector.
        duckdb::FlatVector::GetData<duckdb::string_t>(result)[rid] =
            duckdb::StringVector::AddString(result, duckdb::string_t(digits.c_str()));
    }
}
// Combine callback for mpz_sum: adds each source state's total into the
// target state at the same row position.
//
// Parameters:
//   state_vector - vector of source MPZSumState pointers.
//   combined     - vector of target MPZSumState pointers, read as flat data.
//   count        - number of state pairs to merge.
static void MPZSumCombine(duckdb::Vector &state_vector, duckdb::Vector &combined, duckdb::AggregateInputData &, idx_t count) {
    duckdb::UnifiedVectorFormat source_format;
    state_vector.ToUnifiedFormat(count, source_format);

    auto sources = reinterpret_cast<MPZSumState **>(source_format.data);
    auto targets = duckdb::FlatVector::GetData<MPZSumState *>(combined);

    for (idx_t row = 0; row < count; row++) {
        MPZSumState &source = *sources[source_format.sel->get_index(row)];
        MPZSumState *target = targets[row];
        // target += source
        mpz_add(target->sum, target->sum, source.sum);
    }
}

// Bind callback for mpz_sum: fixes the aggregate's return type to VARCHAR.
//
// Parameters:
//   context   - client context (not used).
//   function  - the aggregate being bound; its return_type is overwritten.
//   arguments - bound argument expressions (not used).
// Returns:
//   nullptr - no bind data is produced.
duckdb::unique_ptr<duckdb::FunctionData> MPZSumBind(duckdb::ClientContext &context,
                                                    duckdb::AggregateFunction &function,
                                                    duckdb::vector<duckdb::unique_ptr<duckdb::Expression>> &arguments) {
    (void)context;
    (void)arguments;

    // The total is handed back as text.
    function.return_type = duckdb::LogicalType(duckdb::LogicalType::VARCHAR);
    return nullptr;
}

// Builds the duckdb::AggregateFunction describing mpz_sum: one VARCHAR
// argument, VARCHAR result, with the callbacks defined in this file.
duckdb::AggregateFunction GetMPZSumFunction() {
    using State = MPZSumState;
    using Ops = MPZSumFunction;

    duckdb::AggregateFunction mpz_sum_function(
        "mpz_sum",                                                // function name
        {duckdb::LogicalType::VARCHAR},                           // argument types: a single string
        duckdb::LogicalType::VARCHAR,                             // return type: string
        duckdb::AggregateFunction::StateSize<State>,              // state size
        duckdb::AggregateFunction::StateInitialize<State, Ops>,   // state initialization
        MPZSumUpdate,                                             // update
        MPZSumCombine,                                            // combine
        MPZSumFinalize,                                           // finalize
        nullptr,                                                  // no simple-update callback
        MPZSumBind,                                               // bind
        duckdb::AggregateFunction::StateDestroy<State, Ops>       // state destruction
    );
    return mpz_sum_function;
}
/*
void MPZSum::RegisterFunction(duckdb::Connection &conn, duckdb::Catalog &catalog) {
    duckdb::AggregateFunctionSet mpz_sum("mpz_sum");
    mpz_sum.AddFunction(GetMPZSumFunction());
    duckdb::CreateAggregateFunctionInfo info(mpz_sum);
    catalog.CreateFunction(*conn.context, info);
}
*/
//}

在quack_extension.cpp中预处理部分最后添加#include "mpz.cpp"
将上面注释掉的注册函数内容复制到quack_extension.cpp的static void LoadInternal(DatabaseInstance &instance) 函数最后, 注意将最后一行改写成LoadInternal前面部分调用的ExtensionUtil::RegisterFunction函数。
因为quack_extension.cpp的函数都在namespace duckdb中,所以函数前面的duckdb::可省略。

    // Create a function set named "mpz_sum" and add the aggregate built by GetMPZSumFunction().
    AggregateFunctionSet mpz_sum("mpz_sum");
    mpz_sum.AddFunction(GetMPZSumFunction());
    // Wrap the set in a CreateAggregateFunctionInfo and register it on the database instance.
    CreateAggregateFunctionInfo info(mpz_sum);
    ExtensionUtil::RegisterFunction(instance, info);

正常编译、添加元数据即可。注意链接gmp动态库。注意将libduckdb.so动态库目录加入LD_LIBRARY_PATH。

export LD_LIBRARY_PATH=/par

g++ -fPIC -shared -o libtest2.so quack_extension.cpp -I /par/duck/src/include -lssl -lcrypto -I include -lduckdb -L /par/duck/build/src -I /par/duckdb-0.10.3/include  -lgmp
root@6ae32a5ffcde:/par/agg# python3 /par/appendmetadata.py -l libtest2.so -n quack -dv v1.3.0  --duckdb-platform linux_amd64 --extension-version 0.1 --abi-type ""
Creating extension binary:
 - Input file: libtest2.so
 - Output file: quack.duckdb_extension
 - Metadata:
   - FIELD8 (unused)            = EMPTY
   - FIELD7 (unused)            = EMPTY
   - FIELD6 (unused)            = EMPTY
   - FIELD5 (abi_type)          =
   - FIELD4 (extension_version) = 0.1
   - FIELD3 (duckdb_version)    = v1.3.0
   - FIELD2 (duckdb_platform)   = linux_amd64
   - FIELD1 (header signature)  = 4 (special value to identify a duckdb extension)
root@6ae32a5ffcde:/par/agg# /par/duckdb130 -unsigned
DuckDB v1.3.0 (Ossivalis) 71c5c07cdd
Enter ".help" for usage hints.
D load '/par/agg/quack.duckdb_extension';
D select mpz_sum(i::varchar)from range(10000000,20000000)t(i);
┌─────────────────────────────┐
│ mpz_sum(CAST(i AS VARCHAR)) │
│           varchar           │
├─────────────────────────────┤
│ 149999995000000             │
└─────────────────────────────┘

说明:本文在2025年5月的extension-template-main.zip模板中测试通过,2025年9月的版本更改了接口,不再使用ExtensionUtil注册函数,需要进一步修改。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值