DeepSeek辅助采用gmp库改写DuckDB大数运算插件

openssl的大数运算库速度有点慢,想试用gmp来改写,告诉DeepSeek

bn库的实现可以了,请用gmp库来实现上述代码中同样的函数

输出如下

#include "duckdb_extension.h"
#include <gmp.h>
#include <string.h>
#include <stdlib.h>
DUCKDB_EXTENSION_EXTERN

// 支持的GMP操作类型
typedef enum {
    GMP_OP_ADD,      // 加法
    GMP_OP_SUB,      // 减法
    GMP_OP_MUL,      // 乘法
    GMP_OP_DIV,      // 除法(返回商)
    GMP_OP_MOD,      // 取模
    GMP_OP_SQR,      // 平方
    GMP_OP_SQRT,     // 平方根
    GMP_OP_EXP,      // 指数
    GMP_OP_MOD_EXP,  // 模指数
    GMP_OP_GCD       // 最大公约数
} gmp_operation_t;

// 从duckdb_string_t提取C字符串
static char* extract_string(duckdb_string_t str) {
    char* result = NULL;
    if (duckdb_string_is_inlined(str)) {
        result = duckdb_malloc(str.value.inlined.length + 1);
        memcpy(result, str.value.inlined.inlined, str.value.inlined.length);
        result[str.value.inlined.length] = '\0';
    } else {
        result = duckdb_malloc(str.value.pointer.length + 1);
        memcpy(result, str.value.pointer.ptr, str.value.pointer.length);
        result[str.value.pointer.length] = '\0';
    }
    return result;
}

// 设置duckdb_string_t为C字符串
static void set_string(duckdb_string_t* str, const char* cstr) {
    size_t len = strlen(cstr);
    if (len <= 12) { // 内联字符串最大长度
        str->value.inlined.length = len;
        memcpy(str->value.inlined.inlined, cstr, len);
        str->value.inlined.inlined[len] = '\0';
    } else {
        char* ptr = duckdb_malloc(len + 1);
        memcpy(ptr, cstr, len);
        ptr[len] = '\0';
        str->value.pointer.ptr = ptr;
        str->value.pointer.length = len;
    }
}

// 通用的GMP操作函数
static void GMPGenericFunction(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output) {
    gmp_operation_t op_type = (gmp_operation_t)(uintptr_t)duckdb_scalar_function_get_extra_info(info);
    
    idx_t input_size = duckdb_data_chunk_get_size(input);
    duckdb_vector a_vec = duckdb_data_chunk_get_vector(input, 0);
    duckdb_string_t* a_data = (duckdb_string_t*)duckdb_vector_get_data(a_vec);
    duckdb_string_t* result_data = (duckdb_string_t*)duckdb_vector_get_data(output);
    
    // 获取有效性位图
    uint64_t* a_validity = duckdb_vector_get_validity(a_vec);
    uint64_t* b_validity = NULL;
    uint64_t* c_validity = NULL;
    
    duckdb_string_t* b_data = NULL;
    duckdb_string_t* c_data = NULL;
    
    // 根据操作类型获取不同的输入向量
    if (op_type != GMP_OP_SQR && op_type != GMP_OP_SQRT) {
        duckdb_vector b_vec = duckdb_data_chunk_get_vector(input, 1);
        b_data = (duckdb_string_t*)duckdb_vector_get_data(b_vec);
        b_validity = duckdb_vector_get_validity(b_vec);
    }
    
    if (op_type == GMP_OP_MOD_EXP) {
        duckdb_vector c_vec = duckdb_data_chunk_get_vector(input, 2);
        c_data = (duckdb_string_t*)duckdb_vector_get_data(c_vec);
        c_validity = duckdb_vector_get_validity(c_vec);
    }

    for (idx_t row = 0; row < input_size; row++) {
        // 检查NULL值
        bool has_null = (a_validity && !duckdb_validity_row_is_valid(a_validity, row)) ||
                       (b_validity && !duckdb_validity_row_is_valid(b_validity, row)) ||
                       (c_validity && !duckdb_validity_row_is_valid(c_validity, row));
        
        if (has_null) {
            duckdb_vector_ensure_validity_writable(output);
            uint64_t* result_validity = duckdb_vector_get_validity(output);
            duckdb_validity_set_row_invalid(result_validity, row);
            continue;
        }

        mpz_t a, b, c, result;
        char* a_str = NULL;
        char* b_str = NULL;
        char* c_str = NULL;
        char* res_str = NULL;

        // 初始化GMP整数
        mpz_init(a);
        mpz_init(result);
        if (b_data) mpz_init(b);
        if (c_data) mpz_init(c);

        // 提取输入字符串
        a_str = extract_string(a_data[row]);
        if (b_data) {
            b_str = extract_string(b_data[row]);
        }
        if (c_data) {
            c_str = extract_string(c_data[row]);
        }

        // 将字符串转换为GMP整数
        if (mpz_set_str(a, a_str, 10) != 0 ||
            (b_data && mpz_set_str(b, b_str, 10) != 0) ||
            (c_data && mpz_set_str(c, c_str, 10) != 0)) {
            goto cleanup;
        }

        // 根据操作类型执行相应的GMP函数
        switch (op_type) {
            case GMP_OP_ADD:
                mpz_add(result, a, b);
                break;
            case GMP_OP_SUB:
                mpz_sub(result, a, b);
                break;
            case GMP_OP_MUL:
                mpz_mul(result, a, b);
                break;
            case GMP_OP_DIV:
                mpz_tdiv_q(result, a, b);  // 向零取整的除法
                break;
            case GMP_OP_MOD:
                mpz_tdiv_r(result, a, b);  // 取余数
                break;
            case GMP_OP_SQR:
                mpz_mul(result, a, a);     // 平方
                break;
            case GMP_OP_SQRT:
                mpz_sqrt(result, a);       // 平方根(向下取整)
                break;
            case GMP_OP_EXP:
                if (mpz_fits_ulong_p(b)) {
                    unsigned long exp = mpz_get_ui(b);
                    mpz_pow_ui(result, a, exp);
                } else {
                    // 指数太大,设为0
                    mpz_set_ui(result, 0);
                }
                break;
            case GMP_OP_MOD_EXP:
                if (mpz_fits_ulong_p(b)) {
                    unsigned long exp = mpz_get_ui(b);
                    mpz_powm_ui(result, a, exp, c);
                } else {
                    // 使用大指数版本
                    mpz_powm(result, a, b, c);
                }
                break;
            case GMP_OP_GCD:
                mpz_gcd(result, a, b);
                break;
            default:
                mpz_set_ui(result, 0);
                break;
        }

        // 将结果转换为字符串
        res_str = mpz_get_str(NULL, 10, result);
        if (res_str) {
            set_string(&result_data[row], res_str);
            free(res_str);
        }

cleanup:
        if (a_str) duckdb_free(a_str);
        if (b_str) duckdb_free(b_str);
        if (c_str) duckdb_free(c_str);
        
        // 清理GMP整数
        mpz_clear(a);
        mpz_clear(result);
        if (b_data) mpz_clear(b);
        if (c_data) mpz_clear(c);
    }
}

// 函数配置信息
typedef struct {
    const char* name;
    gmp_operation_t op_type;
    int num_params;
} gmp_function_config;

static gmp_function_config gmp_functions[] = {
    {"gmp_add", GMP_OP_ADD, 2},
    {"gmp_sub", GMP_OP_SUB, 2},
    {"gmp_mul", GMP_OP_MUL, 2},
    {"gmp_div", GMP_OP_DIV, 2},
    {"gmp_mod", GMP_OP_MOD, 2},
    {"gmp_sqr", GMP_OP_SQR, 1},
    {"gmp_sqrt", GMP_OP_SQRT, 1},
    {"gmp_exp", GMP_OP_EXP, 2},
    {"gmp_mod_exp", GMP_OP_MOD_EXP, 3},
    {"gmp_gcd", GMP_OP_GCD, 2}
};

// 注册所有GMP函数
void RegisterGMPFunctions(duckdb_connection connection) {
    duckdb_logical_type varchar_type = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR);
    
    for (size_t i = 0; i < sizeof(gmp_functions) / sizeof(gmp_functions[0]); i++) {
        gmp_function_config config = gmp_functions[i];
        
        duckdb_scalar_function function = duckdb_create_scalar_function();
        duckdb_scalar_function_set_name(function, config.name);
        
        // 添加参数
        for (int j = 0; j < config.num_params; j++) {
            duckdb_scalar_function_add_parameter(function, varchar_type);
        }
        
        // 设置返回类型
        duckdb_scalar_function_set_return_type(function, varchar_type);
        
        // 设置操作类型作为extra_info
        duckdb_scalar_function_set_extra_info(function, (void*)(uintptr_t)config.op_type, NULL);
        
        // 设置函数实现
        duckdb_scalar_function_set_function(function, GMPGenericFunction);
        
        // 注册函数
        duckdb_register_scalar_function(connection, function);
        duckdb_destroy_scalar_function(&function);
    }
    
    duckdb_destroy_logical_type(&varchar_type);
}

将它保存为gmp.c然后放入bn.c的相同目录,并在capi_quack.c和add_numbers.h中添加RegisterGMPFunctions的调用和声明
编译执行如下,注意需要添加-lgmp选项动态链接gmp库。

root@6ae32a5ffcde:/par/cext# gcc -fPIC -shared -o libtest2.so *.c -I . -lssl -lcrypto -lgmp
root@6ae32a5ffcde:/par/cext# python3 ./appmeta.py -l libtest2.so -n add -dv v1.2.0  --duckdb-platform linux_amd64 --extension-version 0.1
Creating extension binary:
 - Input file: libtest2.so
 - Output file: add.duckdb_extension

root@6ae32a5ffcde:/par/cext# /par/duckdb140 -unsigned
DuckDB v1.4.0 (Andium) b8a06e4a22
Enter ".help" for usage hints.
D load '/par/cext/add.duckdb_extension';
D select gmp_add('10','999');
┌──────────────────────┐
│ gmp_add('10', '999') │
│       varchar        │
├──────────────────────┤
│ 1009                 │
└──────────────────────┘
D .timer on
D with recursive t as(select '1' a,'1' s  union all select (a::int+1)::varchar x, bn_mul(s,x) from t where a::int <1000) from t where a='1000';
┌─────────┬────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
│    a    │                                                     s                                                      │
│ varchar │                                                  varchar                                                   │
├─────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 100040238726007709377354370243392300398571937486421071463254379991042993851239862902059204420848696940480047…  │
└─────────┴────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
Run Time (s): real 0.332 user 0.232357 sys 0.257653
D with recursive t as(select '1' a,'1' s  union all select (a::int+1)::varchar x, gmp_mul(s,x) from t where a::int <1000) from t where a='1000';
┌─────────┬────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
│    a    │                                                     s                                                      │
│ varchar │                                                  varchar                                                   │
├─────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 100040238726007709377354370243392300398571937486421071463254379991042993851239862902059204420848696940480047…  │
└─────────┴────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
Run Time (s): real 0.194 user 0.180921 sys 0.148959
D select gmp_mod('256','7');
┌─────────────────────┐
│ gmp_mod('256', '7') │
│       varchar       │
├─────────────────────┤
│ 4                   │
└─────────────────────┘
Run Time (s): real 0.003 user 0.002576 sys 0.000000

测试计算1000的阶乘,确实gmp库更快, 计算mod也更方便,不用每次都调用BN_new(),大概这也是快的一个原因。

#include <iostream> #include <vector> #include <string> using namespace std; // 模幂运算:使用 Montgomery's Ladder 实现常数时间幂运算 // 计算 (base^exp) % mod,保证执行路径不依赖 exp 的比特值 uint64_t constant_time_mod_exp(uint64_t base, uint64_t exp, uint64_t mod) { uint64_t r0 = 1; uint64_t r1 = base % mod; // 假设使用 64 位指数,从高位开始遍历 for (int i = 63; i >= 0; --i) { uint64_t bit = (exp >> i) & 1; // 统一操作:避免分支预测泄露 uint64_t swap = bit; uint64_t temp_r0 = r0; r0 = (swap * r1 + (1 - swap) * r0) % mod; // 条件选择 r1 if bit==1 r1 = (swap * temp_r0 + (1 - swap) * r1) % mod; // 条件交换 // 平方步骤保持一致 uint64_t r0_sq = (r0 * r0) % mod; uint64_t r1_mul = (r0 * r1) % mod; r0 = r0_sq; r1 = r1_mul; } return r0; } // 扩展欧几里得算法:求 d ≡ e⁻¹ mod φ(n) bool mod_inverse(uint64_t e, uint64_t phi, uint64_t& d) { int64_t t = 0, newt = 1; int64_t r = phi, newr = e; while (newr != 0) { uint64_t quotient = r / newr; swap(t, newt); newt = t - quotient * newt; swap(r, newr); newr = r - quotient * newr; } if (r > 1) return false; // 无逆元 if (t < 0) t += phi; d = t; return true; } // 简化版 RSA 密钥生成(仅用于演示,实际应使用大数) void rsa_keygen(uint64_t p, uint64_t q, uint64_t e, uint64_t& n, uint64_t& d) { n = p * q; uint64_t phi = (p - 1) * (q - 1); bool success = mod_inverse(e, phi, d); if (!success) { cerr << "Error: Inverse does not exist!" << endl; exit(1); } } // RSA 加密:ciphertext = plaintext^e mod n uint64_t rsa_encrypt(uint64_t plaintext, uint64_t e, uint64_t n) { return constant_time_mod_exp(plaintext, e, n); } // RSA 解密:plaintext = ciphertext^d mod n(使用常数时间模幂) uint64_t rsa_decrypt(uint64_t ciphertext, uint64_t d, uint64_t n) { return constant_time_mod_exp(ciphertext, d, n); } // 测试示例 int main() { uint64_t p = 61, q = 53; // 小素数示例 uint64_t e = 17; // 公钥指数 uint64_t n, d; // 生成密钥 rsa_keygen(p, q, e, n, d); cout << "Public Key: (e=" << e << ", n=" << n << ")" << endl; cout << "Private Key: (d=" << d << ", n=" << n << ")" << endl; uint64_t plaintext = 65; cout << "Plaintext: " << plaintext << endl; // 加密 uint64_t ciphertext = rsa_encrypt(plaintext, e, n); cout << "Ciphertext: " << ciphertext << endl; // 解密 uint64_t decrypted = rsa_decrypt(ciphertext, d, n); cout << "Decrypted: " << decrypted << endl; return 0; } 帮我在这个代码的基础上使用大数
09-21
题目背景 材料 1: 请小心地计算下面的算式:138−108÷6=? 你大概难以置信,这个算式的计算结果竟然是 5! 材料 2: 对于一个正整数 𝑥,𝑥!=1×2×⋯×(𝑥−1)×𝑥。我们称 𝑥! 为 𝑥 的阶乘。 特别的,0!=1。 显然,「138−108÷6=5」是错误的,而「(138−108)÷6=5」是正确的,所以对材料 1 中的内容,部分读者会认为「作者没有搞清加减乘除的运算优先级关系而犯错」。 然而,材料 1 最后一行的叹号并不是标点符号,而是材料 2 提到的「阶乘」。 考虑到这一点,「138−108÷6=5!=1×2×⋯×5=120」显然就是正确的了。 题目描述 有关「上述等式为何正确」的问题解决了,然而「如何构造出上述那种让人啼笑皆非的正确等式」成为了一个新的问题。 我们认为这个问题太难了,因此我们把解决这个问题的任务交给了你,相信你可以完成这个任务。 我们会给你一个整数 𝑛,请你帮助求出一组整数 𝑥,𝑦,𝑧,满足 𝑥−𝑦÷𝑧=𝑛! 且 (𝑥−𝑦)÷𝑧=𝑛。 实际上可以发现,当 𝑧=2 时,原式变为 { 𝑥−𝑦2=𝑛! 𝑥2−𝑦2=𝑛 ,这时,只需要让 𝑥=2×(𝑛!−𝑛),并根据任何一个式子计算出 𝑦 的值(为 2×(𝑛!−2𝑛)),即可构成一组合法答案。这样的答案是总是存在的。 因此,按照我们给出的这种方式直接输出 2×(𝑛!−𝑛)、2×(𝑛!−2𝑛)、2 即可通过本题,难点便来到了计算出对应的值上。 当然,你也可以使用其他方法计算出符合要求的 𝑥,𝑦,𝑧。 输入格式 输入共一行一个整数 𝑛。 输出格式 输出共一行三个整数 𝑥,𝑦,𝑧,代表满足 𝑥−𝑦÷𝑧=𝑛! 且 (𝑥−𝑦)÷𝑧=𝑛 的一组整数。 三者两两之间以一个空格隔开。
05-12
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值