编写聚合函数的实现文件 mpz.cpp。文件开头必须包含下面这一组头文件，否则编译时会报错：`variable 'duckdb::CreateAggregateFunctionInfo info' has initializer but incomplete type`（即该类型只有前置声明而没有完整定义，需要包含 create_aggregate_function_info.hpp 等头文件）。
// UDAF mpz_sum: exact, arbitrary-precision sum of integer strings (VARCHAR) using GMP
#pragma once
#include "duckdb.hpp"
#include "duckdb/function/function_set.hpp"
#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp"
#include "duckdb/function/scalar/nested_functions.hpp"
//#include "duckdb/core_functions/aggregate/nested_functions.hpp"
#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
#include "duckdb/common/pair.hpp"
#include "duckdb/planner/expression/bound_function_expression.hpp"
#include "duckdb/common/types/vector.hpp"
#include "gmp.h"
#include <map>
//namespace quackml {
// Per-group aggregate state for mpz_sum.
// The GMP integer is initialised by MPZSumFunction::Initialize (mpz_init) and
// released by MPZSumFunction::Destroy (mpz_clear).
struct MPZSumState {
mpz_t sum; // running total stored as a GMP arbitrary-precision integer
};
// Lifecycle hooks for MPZSumState, invoked through DuckDB's
// AggregateFunction::StateInitialize / StateDestroy templates (see GetMPZSumFunction).
struct MPZSumFunction {
	// Construct the GMP integer inside a freshly allocated state (mpz_init sets it to 0).
	template <class STATE>
	static void Initialize(STATE &st) {
		mpz_init(st.sum);
	}

	// Release the GMP integer's storage when DuckDB tears the state down.
	template <class STATE>
	static void Destroy(STATE &st, duckdb::AggregateInputData & /*aggr_input_data*/) {
		mpz_clear(st.sum);
	}

	// Reports that NULL inputs are ignored (MPZSumUpdate also skips NULL rows itself).
	static bool IgnoreNull() { return true; }
};
// 从string_t安全读取数字字符串
// Extract the text of a DuckDB string_t and validate that it is a base-10
// integer literal that GMP's mpz_set_str() will accept.
//
// Accepted form: an optional leading '-' followed by one or more ASCII digits
// (mpz_set_str understands a leading '-', but not a leading '+').
//
// @param input  DuckDB string value; inlined and pointer-backed storage are
//               both handled by string_t::GetData()/GetSize().
// @return       A std::string copy of the number (NUL-terminable for mpz_set_str).
// @throws std::runtime_error if the value has no digits or contains any other character.
static std::string GetNumericString(const duckdb::string_t& input) {
	// Use the C++ accessors rather than reinterpreting the object as the C-API
	// duckdb_string_t and re-deriving the inline/pointer layout by hand.
	const char *data = input.GetData();
	const idx_t length = input.GetSize();

	// Allow a single leading minus sign so negative values can be summed.
	const idx_t first_digit = (length > 0 && data[0] == '-') ? 1 : 0;
	if (first_digit == length) {
		// Empty string or a lone "-": nothing to parse.
		throw std::runtime_error("Number string contains no digits");
	}
	for (idx_t i = first_digit; i < length; i++) {
		if (data[i] < '0' || data[i] > '9') {
			throw std::runtime_error("Invalid character in number");
		}
	}
	return std::string(data, length);
}
// Update callback for mpz_sum: parse every non-NULL VARCHAR input row and add
// it to the GMP accumulator of the state that row belongs to.
//
// @param inputs        Argument vectors; inputs[0] is the VARCHAR column.
// @param input_count   Number of argument vectors (not used; mpz_sum takes one).
// @param state_vector  Vector of MPZSumState* pointers, one per input row.
// @param count         Number of rows in this batch.
// @throws std::runtime_error ("Error processing number: ...") on malformed input.
static void MPZSumUpdate(duckdb::Vector inputs[], duckdb::AggregateInputData &, idx_t input_count, duckdb::Vector &state_vector, idx_t count) {
	auto &input = inputs[0];
	duckdb::UnifiedVectorFormat sdata;
	state_vector.ToUnifiedFormat(count, sdata);
	duckdb::UnifiedVectorFormat input_data;
	input.ToUnifiedFormat(count, input_data);
	auto states = reinterpret_cast<MPZSumState **>(sdata.data);
	// Loop-invariant: the payload pointer does not depend on the row.
	auto str_values = duckdb::UnifiedVectorFormat::GetData<duckdb::string_t>(input_data);

	// One scratch integer for the whole batch instead of mpz_init/mpz_clear per
	// row; the destructor releases it on both normal exit and exceptions.
	struct ScratchMpz {
		mpz_t value;
		ScratchMpz() { mpz_init(value); }
		~ScratchMpz() { mpz_clear(value); }
		ScratchMpz(const ScratchMpz &) = delete;
		ScratchMpz &operator=(const ScratchMpz &) = delete;
	} tmp;

	for (idx_t i = 0; i < count; i++) {
		const auto input_idx = input_data.sel->get_index(i);
		if (!input_data.validity.RowIsValid(input_idx)) {
			continue; // NULL inputs do not contribute to the sum
		}
		auto &state = *states[sdata.sel->get_index(i)];
		try {
			const std::string num_str = GetNumericString(str_values[input_idx]);
			if (mpz_set_str(tmp.value, num_str.c_str(), 10) != 0) {
				throw std::runtime_error("Failed to convert string to GMP number");
			}
			mpz_add(state.sum, state.sum, tmp.value);
		} catch (const std::exception &e) {
			// Prefix the underlying reason so the user sees which step failed.
			throw std::runtime_error("Error processing number: " + std::string(e.what()));
		}
	}
}
// Finalize callback for mpz_sum: render each state's GMP sum as a base-10
// VARCHAR in the result vector.
//
// @param state_vector  Vector of MPZSumState* pointers.
// @param result        Output VARCHAR vector; row (i + offset) receives state i.
// @param count         Number of states to finalize.
// @param offset        First output row to write.
static void MPZSumFinalize(duckdb::Vector &state_vector, duckdb::AggregateInputData &, duckdb::Vector &result, idx_t count, idx_t offset) {
	duckdb::UnifiedVectorFormat sdata;
	state_vector.ToUnifiedFormat(count, sdata);
	auto states = reinterpret_cast<MPZSumState **>(sdata.data);
	auto result_data = duckdb::FlatVector::GetData<duckdb::string_t>(result);

	// Caller-owned buffer, reused across rows. This avoids mpz_get_str(nullptr, ...),
	// whose result must be released with GMP's own free function (not ::free).
	std::string buffer;
	for (idx_t i = 0; i < count; i++) {
		const auto rid = i + offset;
		auto &state = *states[sdata.sel->get_index(i)];
		// GMP manual: a caller-supplied buffer needs mpz_sizeinbase(op, base) + 2
		// bytes (room for a '-' sign and the NUL terminator).
		buffer.assign(mpz_sizeinbase(state.sum, 10) + 2, '\0');
		mpz_get_str(&buffer[0], 10, state.sum);
		// mpz_sizeinbase may overestimate by one; trim to the terminator GMP wrote.
		buffer.resize(buffer.find('\0'));
		result_data[rid] = duckdb::StringVector::AddString(result, buffer);
	}
}
// Combine callback for mpz_sum: fold each source state's sum into the matching
// target state, i.e. target[row] += source[row].
//
// @param state_vector  Source states (any vector format).
// @param combined      Target states, read as a flat vector of MPZSumState*.
// @param count         Number of state pairs to merge.
static void MPZSumCombine(duckdb::Vector &state_vector, duckdb::Vector &combined, duckdb::AggregateInputData &, idx_t count) {
	duckdb::UnifiedVectorFormat source_format;
	state_vector.ToUnifiedFormat(count, source_format);
	auto sources = reinterpret_cast<MPZSumState **>(source_format.data);
	auto targets = duckdb::FlatVector::GetData<MPZSumState *>(combined);
	for (idx_t row = 0; row < count; ++row) {
		const MPZSumState &src = *sources[source_format.sel->get_index(row)];
		MPZSumState &dst = *targets[row];
		mpz_add(dst.sum, dst.sum, src.sum);
	}
}
// Bind callback for mpz_sum: pins the result type to VARCHAR.
// No bind data is needed, so an empty pointer is returned.
duckdb::unique_ptr<duckdb::FunctionData> MPZSumBind(duckdb::ClientContext & /*context*/, duckdb::AggregateFunction &function,
                                                    duckdb::vector<duckdb::unique_ptr<duckdb::Expression>> & /*arguments*/) {
	// The sum is always rendered as a base-10 string.
	function.return_type = duckdb::LogicalType(duckdb::LogicalTypeId::VARCHAR);
	return duckdb::unique_ptr<duckdb::FunctionData>();
}
// Build the mpz_sum(VARCHAR) -> VARCHAR aggregate from the callbacks above.
// State lifecycle goes through MPZSumFunction (Initialize / Destroy).
duckdb::AggregateFunction GetMPZSumFunction() {
	using duckdb::AggregateFunction;
	using duckdb::LogicalType;
	using State = MPZSumState;
	using Ops = MPZSumFunction;

	return AggregateFunction("mpz_sum",                                  // name
	                         {LogicalType::VARCHAR},                     // argument types
	                         LogicalType::VARCHAR,                       // return type
	                         AggregateFunction::StateSize<State>,        // state size
	                         AggregateFunction::StateInitialize<State, Ops>,
	                         MPZSumUpdate, MPZSumCombine, MPZSumFinalize,
	                         /* simple_update */ nullptr,
	                         MPZSumBind,
	                         AggregateFunction::StateDestroy<State, Ops>);
}
/*
void MPZSum::RegisterFunction(duckdb::Connection &conn, duckdb::Catalog &catalog) {
duckdb::AggregateFunctionSet mpz_sum("mpz_sum");
mpz_sum.AddFunction(GetMPZSumFunction());
duckdb::CreateAggregateFunctionInfo info(mpz_sum);
catalog.CreateFunction(*conn.context, info);
}
*/
//}
在 quack_extension.cpp 预处理部分（即各条 #include 语句）的最后添加 `#include "mpz.cpp"`。
将上面注释掉的注册函数的函数体复制到 quack_extension.cpp 中 `static void LoadInternal(DatabaseInstance &instance)` 函数的末尾，并把其中最后一行（catalog.CreateFunction 调用）改写为 LoadInternal 前面部分已经在使用的 `ExtensionUtil::RegisterFunction` 调用。
由于 quack_extension.cpp 中的函数都位于 namespace duckdb 内，类型名前面的 `duckdb::` 前缀可以省略。改写后的代码如下：
// Append to the end of LoadInternal(DatabaseInstance &instance) in quack_extension.cpp.
// That code is inside namespace duckdb, so the duckdb:: qualifiers are omitted here.
AggregateFunctionSet mpz_sum("mpz_sum");
mpz_sum.AddFunction(GetMPZSumFunction());
CreateAggregateFunctionInfo info(mpz_sum);
ExtensionUtil::RegisterFunction(instance, info);
之后按常规方式编译并添加扩展元数据即可。编译时注意用 -lgmp 链接 GMP 动态库，运行前注意将 libduckdb.so 所在目录加入 LD_LIBRARY_PATH。
export LD_LIBRARY_PATH=/par
g++ -fPIC -shared -o libtest2.so quack_extension.cpp -I /par/duck/src/include -lssl -lcrypto -I include -lduckdb -L /par/duck/build/src -I /par/duckdb-0.10.3/include -lgmp
root@6ae32a5ffcde:/par/agg# python3 /par/appendmetadata.py -l libtest2.so -n quack -dv v1.3.0 --duckdb-platform linux_amd64 --extension-version 0.1 --abi-type ""
Creating extension binary:
- Input file: libtest2.so
- Output file: quack.duckdb_extension
- Metadata:
- FIELD8 (unused) = EMPTY
- FIELD7 (unused) = EMPTY
- FIELD6 (unused) = EMPTY
- FIELD5 (abi_type) =
- FIELD4 (extension_version) = 0.1
- FIELD3 (duckdb_version) = v1.3.0
- FIELD2 (duckdb_platform) = linux_amd64
- FIELD1 (header signature) = 4 (special value to identify a duckdb extension)
root@6ae32a5ffcde:/par/agg# /par/duckdb130 -unsigned
DuckDB v1.3.0 (Ossivalis) 71c5c07cdd
Enter ".help" for usage hints.
D load '/par/agg/quack.duckdb_extension';
D select mpz_sum(i::varchar)from range(10000000,20000000)t(i);
┌─────────────────────────────┐
│ mpz_sum(CAST(i AS VARCHAR)) │
│ varchar │
├─────────────────────────────┤
│ 149999995000000 │
└─────────────────────────────┘
说明：本文在 2025 年 5 月的 extension-template-main.zip 模板中测试通过。2025 年 9 月的模板版本更改了接口，不再使用 ExtensionUtil 注册函数，需要进一步修改。
为DuckDB插件添加mpz_sum聚合函数


被折叠的 条评论
为什么被折叠?



