DeepSeek辅助整合精简的DuckDB大数运算插件

前文的插件,每个运算都需要一个函数,其中函数体大部分内容都重复,给他一个要实现的操作函数列表,连同文档起发给他,让他整合。

请把减法、除法、取模、平方、平方根等bn函数也实现,另外现在代码冗余很多,请合理地合并大部分相同代码,比如注册用同一个变量循环,调用函数时根据传入的额外参数决定调用哪个特定的bn底层函数,

他一开始给出的代码调用了不存在的duckdb函数,我把文档中存在的类似函数发给他,它自己选了一个代替就好了。

根据你提供的函数列表,应该使用以下语句代替:

// 设置操作类型作为extra_info
duckdb_scalar_function_set_extra_info(function, (void*)(uintptr_t)config.op_type, NULL);

// 在函数实现中获取extra_info
bn_operation_t op_type = (bn_operation_t)(uintptr_t)duckdb_scalar_function_get_extra_info(info);

使用 duckdb_scalar_function_set_extra_infoduckdb_scalar_function_get_extra_info 来传递和获取操作类型参数。

编译成功了,但是我的openssl库中没有BN_mod函数,我看文档中说BN_mod() corresponds to BN_div() with dv set to NULL.就改成了case BN_OP_MOD:
success = BN_div(NULL, rem, a, b, ctx);但是执行结果好像不对。select bn_mod(‘1’,‘2’);
0 ,而duckdb内置函数select mod(1,2);
1 ,请修改正确,不做别的

问题在于 BN_div 的使用方式。当 dv 设置为 NULL 时,rem 参数应该接收余数结果。正确的修改如下:

case BN_OP_MOD:
    rem = BN_new();
    success = BN_div(NULL, rem, a, b, ctx);
    if (success) {
        // 将余数结果复制到result中
        success = BN_copy(result, rem);
    }
    if (rem) BN_free(rem);
    break;

这样修改后,bn_mod('1','2') 应该返回 1,与 DuckDB 内置的 mod(1,2) 结果一致。

我把代码复制粘贴到原代码,执行bn_mod总是内存错误,我又让他单独写了一个测试bn_mod的程序,没有问题。
将它与前面调用相同的BN_div函数的除法代码比较,少了一行rem = BN_new();
可笑的是,ds的代码就有rem = BN_new();,我粘贴漏了,白白浪费了很多功夫。

正确的完整代码如下

#include "duckdb_extension.h"
#include <string.h>
#include <stdlib.h>
// 声明BN函数(来自OpenSSL)
//extern "C" {
    typedef struct bignum_st BIGNUM;
    typedef struct bn_ctx_st BN_CTX;
    
    BIGNUM* BN_new();
    int BN_copy(BIGNUM*, const BIGNUM*);
    void BN_free(BIGNUM*);
    int BN_dec2bn(BIGNUM**, const char*);
    char* BN_bn2dec(const BIGNUM*);
    int BN_add(BIGNUM*, const BIGNUM*, const BIGNUM*);
    int BN_sub(BIGNUM*, const BIGNUM*, const BIGNUM*);
    int BN_mul(BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
    int BN_div(BIGNUM*, BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
    //int BN_mod(BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
    int BN_sqr(BIGNUM*, const BIGNUM*, BN_CTX*);
    BIGNUM* BN_mod_sqrt(BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
    int BN_exp(BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
    int BN_mod_exp(BIGNUM*, const BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
    int BN_gcd(BIGNUM*, const BIGNUM*, const BIGNUM*, BN_CTX*);
    BN_CTX* BN_CTX_new();
    void BN_CTX_free(BN_CTX*);
    void free(void*);
//}

DUCKDB_EXTENSION_EXTERN

// 支持的BN操作类型
typedef enum {
    BN_OP_ADD,      // 加法
    BN_OP_SUB,      // 减法
    BN_OP_MUL,      // 乘法
    BN_OP_DIV,      // 除法(返回商)
    BN_OP_MOD,      // 取模
    BN_OP_SQR,      // 平方
    BN_OP_SQRT,     // 平方根
    BN_OP_EXP,      // 指数
    BN_OP_MOD_EXP,  // 模指数
    BN_OP_GCD       // 最大公约数
} bn_operation_t;

// 从duckdb_string_t提取C字符串
static char* extract_string(duckdb_string_t str) {
    char* result = NULL;
    if (duckdb_string_is_inlined(str)) {
        result = duckdb_malloc(str.value.inlined.length + 1);
        memcpy(result, str.value.inlined.inlined, str.value.inlined.length);
        result[str.value.inlined.length] = '\0';
    } else {
        result = duckdb_malloc(str.value.pointer.length + 1);
        memcpy(result, str.value.pointer.ptr, str.value.pointer.length);
        result[str.value.pointer.length] = '\0';
    }
    return result;
}

// 设置duckdb_string_t为C字符串
static void set_string(duckdb_string_t* str, const char* cstr) {
    size_t len = strlen(cstr);
    if (len <= 12) { // 内联字符串最大长度
        str->value.inlined.length = len;
        memcpy(str->value.inlined.inlined, cstr, len);
        str->value.inlined.inlined[len] = '\0';
    } else {
        char* ptr = duckdb_malloc(len + 1);
        memcpy(ptr, cstr, len);
        ptr[len] = '\0';
        str->value.pointer.ptr = ptr;
        str->value.pointer.length = len;
    }
}

// 通用的BN操作函数
static void BNGenericFunction(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output) {
    //bn_operation_t op_type = (bn_operation_t)duckdb_function_info_get_bind_data(info);
    // 在函数实现中获取extra_info
    bn_operation_t op_type = (bn_operation_t)(uintptr_t)duckdb_scalar_function_get_extra_info(info);
    
    idx_t input_size = duckdb_data_chunk_get_size(input);
    duckdb_vector a_vec = duckdb_data_chunk_get_vector(input, 0);
    duckdb_string_t* a_data = (duckdb_string_t*)duckdb_vector_get_data(a_vec);
    duckdb_string_t* result_data = (duckdb_string_t*)duckdb_vector_get_data(output);
    
    // 获取有效性位图
    uint64_t* a_validity = duckdb_vector_get_validity(a_vec);
    uint64_t* b_validity = NULL;
    uint64_t* c_validity = NULL;
    
    duckdb_string_t* b_data = NULL;
    duckdb_string_t* c_data = NULL;
    
    // 根据操作类型获取不同的输入向量
    if (op_type != BN_OP_SQR && op_type != BN_OP_SQRT) {
        duckdb_vector b_vec = duckdb_data_chunk_get_vector(input, 1);
        b_data = (duckdb_string_t*)duckdb_vector_get_data(b_vec);
        b_validity = duckdb_vector_get_validity(b_vec);
    }
    
    if (op_type == BN_OP_MOD_EXP) {
        duckdb_vector c_vec = duckdb_data_chunk_get_vector(input, 2);
        c_data = (duckdb_string_t*)duckdb_vector_get_data(c_vec);
        c_validity = duckdb_vector_get_validity(c_vec);
    }

    BN_CTX* ctx = BN_CTX_new();
    if (!ctx) return;

    for (idx_t row = 0; row < input_size; row++) {
        // 检查NULL值
        bool has_null = (a_validity && !duckdb_validity_row_is_valid(a_validity, row)) ||
                       (b_validity && !duckdb_validity_row_is_valid(b_validity, row)) ||
                       (c_validity && !duckdb_validity_row_is_valid(c_validity, row));
        
        if (has_null) {
            duckdb_vector_ensure_validity_writable(output);
            uint64_t* result_validity = duckdb_vector_get_validity(output);
            duckdb_validity_set_row_invalid(result_validity, row);
            continue;
        }

        BIGNUM* a = BN_new();
        BIGNUM* b = NULL;
        BIGNUM* c = NULL;
        BIGNUM* result = BN_new();
        BIGNUM* rem = NULL;
        char* a_str = NULL;
        char* b_str = NULL;
        char* c_str = NULL;
        char* res_str = NULL;
        int success = 0;

        if (!a || !result) {
            goto cleanup;
        }

        // 提取输入字符串
        a_str = extract_string(a_data[row]);
        if (b_data) {
            b_str = extract_string(b_data[row]);
            b = BN_new();
            if (!b) goto cleanup;
        }
        if (c_data) {
            c_str = extract_string(c_data[row]);
            c = BN_new();
            if (!c) goto cleanup;
        }

        // 将字符串转换为大数
        if (!BN_dec2bn(&a, a_str) || 
            (b && !BN_dec2bn(&b, b_str)) || 
            (c && !BN_dec2bn(&c, c_str))) {
            goto cleanup;
        }

        // 根据操作类型执行相应的BN函数
        switch (op_type) {
            case BN_OP_ADD:
                success = BN_add(result, a, b);
                break;
            case BN_OP_SUB:
                success = BN_sub(result, a, b);
                break;
            case BN_OP_MUL:
                success = BN_mul(result, a, b, ctx);
                break;
            case BN_OP_DIV:
                rem = BN_new();
                success = BN_div(result, rem, a, b, ctx);
                if (rem) BN_free(rem);
                break;
            case BN_OP_MOD:
                rem = BN_new();
                success = BN_div(NULL , rem, a, b, ctx);//BN_mod(result, a, b, ctx);
                if (success) {
                   // 将余数结果复制到result中
                   success = BN_copy(result, rem);
                }
                if (rem) BN_free(rem);
                break;
            case BN_OP_SQR:
                success = BN_sqr(result, a, ctx);
                break;
            case BN_OP_SQRT:
                {
                    BIGNUM* sqrt_result = BN_mod_sqrt(result, a, b, ctx);
                    success = (sqrt_result != NULL);
                }
                break;
            case BN_OP_EXP:
                success = BN_exp(result, a, b, ctx);
                break;
            case BN_OP_MOD_EXP:
                success = BN_mod_exp(result, a, b, c, ctx);
                break;
            case BN_OP_GCD:
                success = BN_gcd(result, a, b, ctx);
                break;
            default:
                success = 0;
                break;
        }

        if (!success) {
            goto cleanup;
        }

        // 将结果转换为字符串
      
        res_str = BN_bn2dec(result);
        if (res_str) {
            set_string(&result_data[row], res_str);
            free(res_str);
        }

cleanup:
        if (a_str) duckdb_free(a_str);
        if (b_str) duckdb_free(b_str);
        if (c_str) duckdb_free(c_str);
        if (a) BN_free(a);
        if (b) BN_free(b);
        if (c) BN_free(c);
        if (result) BN_free(result);
    }

    BN_CTX_free(ctx);
}

// 函数配置信息
typedef struct {
    const char* name;
    bn_operation_t op_type;
    int num_params;
} bn_function_config;

static bn_function_config bn_functions[] = {
    {"bn_add", BN_OP_ADD, 2},
    {"bn_sub", BN_OP_SUB, 2},
    {"bn_mul", BN_OP_MUL, 2},
    {"bn_div", BN_OP_DIV, 2},
    {"bn_mod", BN_OP_MOD, 2},
    {"bn_sqr", BN_OP_SQR, 1},
    {"bn_sqrt", BN_OP_SQRT, 2},
    {"bn_exp", BN_OP_EXP, 2},
    {"bn_mod_exp", BN_OP_MOD_EXP, 3},
    {"bn_gcd", BN_OP_GCD, 2}
};

// 注册所有BN函数
void RegisterBNFunctions(duckdb_connection connection) {
    duckdb_logical_type varchar_type = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR);
    
    for (size_t i = 0; i < sizeof(bn_functions) / sizeof(bn_functions[0]); i++) {
        bn_function_config config = bn_functions[i];
        
        duckdb_scalar_function function = duckdb_create_scalar_function();
        duckdb_scalar_function_set_name(function, config.name);
        
        // 添加参数
        for (int j = 0; j < config.num_params; j++) {
            duckdb_scalar_function_add_parameter(function, varchar_type);
        }
        
        // 设置返回类型
        duckdb_scalar_function_set_return_type(function, varchar_type);
        
        // 设置操作类型作为bind_data
        // duckdb_function_info info = duckdb_scalar_function_get_function_info(function);
        // duckdb_function_info_set_bind_data(info, (void*)(uintptr_t)config.op_type);
        // 设置操作类型作为extra_info
        duckdb_scalar_function_set_extra_info(function, (void*)(uintptr_t)config.op_type, NULL);        
        // 设置函数实现
        duckdb_scalar_function_set_function(function, BNGenericFunction);
        
        // 注册函数
        duckdb_register_scalar_function(connection, function);
        duckdb_destroy_scalar_function(&function);
    }
    
    duckdb_destroy_logical_type(&varchar_type);
}

编译运行成功。

/par/cext# gcc -fPIC -shared -o libtest2.so *.c -I . -lssl -lcrypto
/par/cext# python3 ./appmeta.py -l libtest2.so -n add -dv v1.2.0  --duckdb-platform linux_arm64 --extension-version 0.1
Creating extension binary:
 - Input file: libtest2.so
 - Output file: add.duckdb_extension

/par/cext# /par/duckdb140 -unsigned
DuckDB v1.4.0 (Andium) b8a06e4a22
Enter ".help" for usage hints.
D load '/par/cext/add.duckdb_extension';
D select bn_mod('256','7');
┌────────────────────┐
│ bn_mod('256', '7') │
│      varchar       │
├────────────────────┤
│ 4                  │
└────────────────────┘
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值