前文的插件,每个运算都需要一个函数,其中函数体大部分内容都重复,给他一个要实现的操作函数列表,连同文档起发给他,让他整合。
请把减法、除法、取模、平方、平方根等bn函数也实现,另外现在代码冗余很多,请合理地合并大部分相同代码,比如注册用同一个变量循环,调用函数时根据传入的额外参数决定调用哪个特定的bn底层函数,
他一开始给出的代码调用了不存在的duckdb函数,我把文档中存在的类似函数发给他,它自己选了一个代替就好了。
根据你提供的函数列表,应该使用以下语句代替:
// 设置操作类型作为extra_info
duckdb_scalar_function_set_extra_info(function, (void*)(uintptr_t)config.op_type, NULL);
// 在函数实现中获取extra_info
bn_operation_t op_type = (bn_operation_t)(uintptr_t)duckdb_scalar_function_get_extra_info(info);
使用 duckdb_scalar_function_set_extra_info 和 duckdb_scalar_function_get_extra_info 来传递和获取操作类型参数。
编译成功了,但是我的openssl库中没有BN_mod函数,我看文档中说BN_mod() corresponds to BN_div() with dv set to NULL.就改成了case BN_OP_MOD:
success = BN_div(NULL, rem, a, b, ctx);但是执行结果好像不对。select bn_mod(‘1’,‘2’);
0 ,而duckdb内置函数select mod(1,2);
1 ,请修改正确,不做别的
问题在于 BN_div 的使用方式。当 dv 设置为 NULL 时,rem 参数应该接收余数结果。正确的修改如下:
case BN_OP_MOD:
rem = BN_new();
success = BN_div(NULL, rem, a, b, ctx);
if (success) {
// 将余数结果复制到result中
success = BN_copy(result, rem);
}
if (rem) BN_free(rem);
break;
这样修改后,bn_mod('1','2') 应该返回 1,与 DuckDB 内置的 mod(1,2) 结果一致。
我把代码复制粘贴到原代码,执行bn_mod总是内存错误,我又让他单独写了一个测试bn_mod的程序,没有问题。
将它与前面调用相同的BN_div函数的除法代码比较,少了一行rem = BN_new();,
可笑的是,ds的代码就有rem = BN_new();,我粘贴漏了,白白浪费了很多功夫。
正确的完整代码如下
#include "duckdb_extension.h"
#include <string.h>
#include <stdlib.h>
// 声明BN函数(来自OpenSSL)
//extern "C" {
typedef struct bignum_st BIGNUM;
typedef struct bn_ctx_st BN_CTX;
BIGNUM* BN_new();
int BN_copy(BIGNUM*, const BIGNUM*);
void BN_free(BIGNUM*);
int BN_dec2bn(BIGNUM**, const char*);
char* BN_bn2dec(const BIGNUM*);
int BN_add(BIGNUM*, const BIGNUM*, const BIGNUM*);
int BN_sub(BIGNUM*, const BIGNUM*, const BIGNUM*);
int BN_mul(BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
int BN_div(BIGNUM*, BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
//int BN_mod(BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
int BN_sqr(BIGNUM*, const BIGNUM*, BN_CTX*);
BIGNUM* BN_mod_sqrt(BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
int BN_exp(BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
int BN_mod_exp(BIGNUM*, const BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
int BN_gcd(BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
BN_CTX* BN_CTX_new();
void BN_CTX_free(BN_CTX*);
void free(void*);
//}
DUCKDB_EXTENSION_EXTERN
// 支持的BN操作类型
typedef enum {
BN_OP_ADD, // 加法
BN_OP_SUB, // 减法
BN_OP_MUL, // 乘法
BN_OP_DIV, // 除法(返回商)
BN_OP_MOD, // 取模
BN_OP_SQR, // 平方
BN_OP_SQRT, // 平方根
BN_OP_EXP, // 指数
BN_OP_MOD_EXP, // 模指数
BN_OP_GCD // 最大公约数
} bn_operation_t;
// 从duckdb_string_t提取C字符串
static char* extract_string(duckdb_string_t str) {
char* result = NULL;
if (duckdb_string_is_inlined(str)) {
result = duckdb_malloc(str.value.inlined.length + 1);
memcpy(result, str.value.inlined.inlined, str.value.inlined.length);
result[str.value.inlined.length] = '\0';
} else {
result = duckdb_malloc(str.value.pointer.length + 1);
memcpy(result, str.value.pointer.ptr, str.value.pointer.length);
result[str.value.pointer.length] = '\0';
}
return result;
}
// 设置duckdb_string_t为C字符串
static void set_string(duckdb_string_t* str, const char* cstr) {
size_t len = strlen(cstr);
if (len <= 12) { // 内联字符串最大长度
str->value.inlined.length = len;
memcpy(str->value.inlined.inlined, cstr, len);
str->value.inlined.inlined[len] = '\0';
} else {
char* ptr = duckdb_malloc(len + 1);
memcpy(ptr, cstr, len);
ptr[len] = '\0';
str->value.pointer.ptr = ptr;
str->value.pointer.length = len;
}
}
// 通用的BN操作函数
static void BNGenericFunction(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output) {
//bn_operation_t op_type = (bn_operation_t)duckdb_function_info_get_bind_data(info);
// 在函数实现中获取extra_info
bn_operation_t op_type = (bn_operation_t)(uintptr_t)duckdb_scalar_function_get_extra_info(info);
idx_t input_size = duckdb_data_chunk_get_size(input);
duckdb_vector a_vec = duckdb_data_chunk_get_vector(input, 0);
duckdb_string_t* a_data = (duckdb_string_t*)duckdb_vector_get_data(a_vec);
duckdb_string_t* result_data = (duckdb_string_t*)duckdb_vector_get_data(output);
// 获取有效性位图
uint64_t* a_validity = duckdb_vector_get_validity(a_vec);
uint64_t* b_validity = NULL;
uint64_t* c_validity = NULL;
duckdb_string_t* b_data = NULL;
duckdb_string_t* c_data = NULL;
// 根据操作类型获取不同的输入向量
if (op_type != BN_OP_SQR && op_type != BN_OP_SQRT) {
duckdb_vector b_vec = duckdb_data_chunk_get_vector(input, 1);
b_data = (duckdb_string_t*)duckdb_vector_get_data(b_vec);
b_validity = duckdb_vector_get_validity(b_vec);
}
if (op_type == BN_OP_MOD_EXP) {
duckdb_vector c_vec = duckdb_data_chunk_get_vector(input, 2);
c_data = (duckdb_string_t*)duckdb_vector_get_data(c_vec);
c_validity = duckdb_vector_get_validity(c_vec);
}
BN_CTX* ctx = BN_CTX_new();
if (!ctx) return;
for (idx_t row = 0; row < input_size; row++) {
// 检查NULL值
bool has_null = (a_validity && !duckdb_validity_row_is_valid(a_validity, row)) ||
(b_validity && !duckdb_validity_row_is_valid(b_validity, row)) ||
(c_validity && !duckdb_validity_row_is_valid(c_validity, row));
if (has_null) {
duckdb_vector_ensure_validity_writable(output);
uint64_t* result_validity = duckdb_vector_get_validity(output);
duckdb_validity_set_row_invalid(result_validity, row);
continue;
}
BIGNUM* a = BN_new();
BIGNUM* b = NULL;
BIGNUM* c = NULL;
BIGNUM* result = BN_new();
BIGNUM* rem = NULL;
char* a_str = NULL;
char* b_str = NULL;
char* c_str = NULL;
char* res_str = NULL;
int success = 0;
if (!a || !result) {
goto cleanup;
}
// 提取输入字符串
a_str = extract_string(a_data[row]);
if (b_data) {
b_str = extract_string(b_data[row]);
b = BN_new();
if (!b) goto cleanup;
}
if (c_data) {
c_str = extract_string(c_data[row]);
c = BN_new();
if (!c) goto cleanup;
}
// 将字符串转换为大数
if (!BN_dec2bn(&a, a_str) ||
(b && !BN_dec2bn(&b, b_str)) ||
(c && !BN_dec2bn(&c, c_str))) {
goto cleanup;
}
// 根据操作类型执行相应的BN函数
switch (op_type) {
case BN_OP_ADD:
success = BN_add(result, a, b);
break;
case BN_OP_SUB:
success = BN_sub(result, a, b);
break;
case BN_OP_MUL:
success = BN_mul(result, a, b, ctx);
break;
case BN_OP_DIV:
rem = BN_new();
success = BN_div(result, rem, a, b, ctx);
if (rem) BN_free(rem);
break;
case BN_OP_MOD:
rem = BN_new();
success = BN_div(NULL , rem, a, b, ctx);//BN_mod(result, a, b, ctx);
if (success) {
// 将余数结果复制到result中
success = BN_copy(result, rem);
}
if (rem) BN_free(rem);
break;
case BN_OP_SQR:
success = BN_sqr(result, a, ctx);
break;
case BN_OP_SQRT:
{
BIGNUM* sqrt_result = BN_mod_sqrt(result, a, b, ctx);
success = (sqrt_result != NULL);
}
break;
case BN_OP_EXP:
success = BN_exp(result, a, b, ctx);
break;
case BN_OP_MOD_EXP:
success = BN_mod_exp(result, a, b, c, ctx);
break;
case BN_OP_GCD:
success = BN_gcd(result, a, b, ctx);
break;
default:
success = 0;
break;
}
if (!success) {
goto cleanup;
}
// 将结果转换为字符串
res_str = BN_bn2dec(result);
if (res_str) {
set_string(&result_data[row], res_str);
free(res_str);
}
cleanup:
if (a_str) duckdb_free(a_str);
if (b_str) duckdb_free(b_str);
if (c_str) duckdb_free(c_str);
if (a) BN_free(a);
if (b) BN_free(b);
if (c) BN_free(c);
if (result) BN_free(result);
}
BN_CTX_free(ctx);
}
// 函数配置信息
typedef struct {
const char* name;
bn_operation_t op_type;
int num_params;
} bn_function_config;
static bn_function_config bn_functions[] = {
{"bn_add", BN_OP_ADD, 2},
{"bn_sub", BN_OP_SUB, 2},
{"bn_mul", BN_OP_MUL, 2},
{"bn_div", BN_OP_DIV, 2},
{"bn_mod", BN_OP_MOD, 2},
{"bn_sqr", BN_OP_SQR, 1},
{"bn_sqrt", BN_OP_SQRT, 2},
{"bn_exp", BN_OP_EXP, 2},
{"bn_mod_exp", BN_OP_MOD_EXP, 3},
{"bn_gcd", BN_OP_GCD, 2}
};
// 注册所有BN函数
void RegisterBNFunctions(duckdb_connection connection) {
duckdb_logical_type varchar_type = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR);
for (size_t i = 0; i < sizeof(bn_functions) / sizeof(bn_functions[0]); i++) {
bn_function_config config = bn_functions[i];
duckdb_scalar_function function = duckdb_create_scalar_function();
duckdb_scalar_function_set_name(function, config.name);
// 添加参数
for (int j = 0; j < config.num_params; j++) {
duckdb_scalar_function_add_parameter(function, varchar_type);
}
// 设置返回类型
duckdb_scalar_function_set_return_type(function, varchar_type);
// 设置操作类型作为bind_data
// duckdb_function_info info = duckdb_scalar_function_get_function_info(function);
// duckdb_function_info_set_bind_data(info, (void*)(uintptr_t)config.op_type);
// 设置操作类型作为extra_info
duckdb_scalar_function_set_extra_info(function, (void*)(uintptr_t)config.op_type, NULL);
// 设置函数实现
duckdb_scalar_function_set_function(function, BNGenericFunction);
// 注册函数
duckdb_register_scalar_function(connection, function);
duckdb_destroy_scalar_function(&function);
}
duckdb_destroy_logical_type(&varchar_type);
}
编译运行成功。
/par/cext# gcc -fPIC -shared -o libtest2.so *.c -I . -lssl -lcrypto
/par/cext# python3 ./appmeta.py -l libtest2.so -n add -dv v1.2.0 --duckdb-platform linux_arm64 --extension-version 0.1
Creating extension binary:
- Input file: libtest2.so
- Output file: add.duckdb_extension
/par/cext# /par/duckdb140 -unsigned
DuckDB v1.4.0 (Andium) b8a06e4a22
Enter ".help" for usage hints.
D load '/par/cext/add.duckdb_extension';
D select bn_mod('256','7');
┌────────────────────┐
│ bn_mod('256', '7') │
│ varchar │
├────────────────────┤
│ 4 │
└────────────────────┘

3708

被折叠的 条评论
为什么被折叠?



