openssl的大数运算库速度有点慢,想试用gmp来改写,告诉DeepSeek
bn库的实现可以了,请用gmp库来实现上述代码中同样的函数
输出如下
#include "duckdb_extension.h"
#include <gmp.h>
#include <string.h>
#include <stdlib.h>
DUCKDB_EXTENSION_EXTERN
// 支持的GMP操作类型
typedef enum {
GMP_OP_ADD, // 加法
GMP_OP_SUB, // 减法
GMP_OP_MUL, // 乘法
GMP_OP_DIV, // 除法(返回商)
GMP_OP_MOD, // 取模
GMP_OP_SQR, // 平方
GMP_OP_SQRT, // 平方根
GMP_OP_EXP, // 指数
GMP_OP_MOD_EXP, // 模指数
GMP_OP_GCD // 最大公约数
} gmp_operation_t;
// 从duckdb_string_t提取C字符串
static char* extract_string(duckdb_string_t str) {
char* result = NULL;
if (duckdb_string_is_inlined(str)) {
result = duckdb_malloc(str.value.inlined.length + 1);
memcpy(result, str.value.inlined.inlined, str.value.inlined.length);
result[str.value.inlined.length] = '\0';
} else {
result = duckdb_malloc(str.value.pointer.length + 1);
memcpy(result, str.value.pointer.ptr, str.value.pointer.length);
result[str.value.pointer.length] = '\0';
}
return result;
}
// 设置duckdb_string_t为C字符串
static void set_string(duckdb_string_t* str, const char* cstr) {
size_t len = strlen(cstr);
if (len <= 12) { // 内联字符串最大长度
str->value.inlined.length = len;
memcpy(str->value.inlined.inlined, cstr, len);
str->value.inlined.inlined[len] = '\0';
} else {
char* ptr = duckdb_malloc(len + 1);
memcpy(ptr, cstr, len);
ptr[len] = '\0';
str->value.pointer.ptr = ptr;
str->value.pointer.length = len;
}
}
// 通用的GMP操作函数
static void GMPGenericFunction(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output) {
gmp_operation_t op_type = (gmp_operation_t)(uintptr_t)duckdb_scalar_function_get_extra_info(info);
idx_t input_size = duckdb_data_chunk_get_size(input);
duckdb_vector a_vec = duckdb_data_chunk_get_vector(input, 0);
duckdb_string_t* a_data = (duckdb_string_t*)duckdb_vector_get_data(a_vec);
duckdb_string_t* result_data = (duckdb_string_t*)duckdb_vector_get_data(output);
// 获取有效性位图
uint64_t* a_validity = duckdb_vector_get_validity(a_vec);
uint64_t* b_validity = NULL;
uint64_t* c_validity = NULL;
duckdb_string_t* b_data = NULL;
duckdb_string_t* c_data = NULL;
// 根据操作类型获取不同的输入向量
if (op_type != GMP_OP_SQR && op_type != GMP_OP_SQRT) {
duckdb_vector b_vec = duckdb_data_chunk_get_vector(input, 1);
b_data = (duckdb_string_t*)duckdb_vector_get_data(b_vec);
b_validity = duckdb_vector_get_validity(b_vec);
}
if (op_type == GMP_OP_MOD_EXP) {
duckdb_vector c_vec = duckdb_data_chunk_get_vector(input, 2);
c_data = (duckdb_string_t*)duckdb_vector_get_data(c_vec);
c_validity = duckdb_vector_get_validity(c_vec);
}
for (idx_t row = 0; row < input_size; row++) {
// 检查NULL值
bool has_null = (a_validity && !duckdb_validity_row_is_valid(a_validity, row)) ||
(b_validity && !duckdb_validity_row_is_valid(b_validity, row)) ||
(c_validity && !duckdb_validity_row_is_valid(c_validity, row));
if (has_null) {
duckdb_vector_ensure_validity_writable(output);
uint64_t* result_validity = duckdb_vector_get_validity(output);
duckdb_validity_set_row_invalid(result_validity, row);
continue;
}
mpz_t a, b, c, result;
char* a_str = NULL;
char* b_str = NULL;
char* c_str = NULL;
char* res_str = NULL;
// 初始化GMP整数
mpz_init(a);
mpz_init(result);
if (b_data) mpz_init(b);
if (c_data) mpz_init(c);
// 提取输入字符串
a_str = extract_string(a_data[row]);
if (b_data) {
b_str = extract_string(b_data[row]);
}
if (c_data) {
c_str = extract_string(c_data[row]);
}
// 将字符串转换为GMP整数
if (mpz_set_str(a, a_str, 10) != 0 ||
(b_data && mpz_set_str(b, b_str, 10) != 0) ||
(c_data && mpz_set_str(c, c_str, 10) != 0)) {
goto cleanup;
}
// 根据操作类型执行相应的GMP函数
switch (op_type) {
case GMP_OP_ADD:
mpz_add(result, a, b);
break;
case GMP_OP_SUB:
mpz_sub(result, a, b);
break;
case GMP_OP_MUL:
mpz_mul(result, a, b);
break;
case GMP_OP_DIV:
mpz_tdiv_q(result, a, b); // 向零取整的除法
break;
case GMP_OP_MOD:
mpz_tdiv_r(result, a, b); // 取余数
break;
case GMP_OP_SQR:
mpz_mul(result, a, a); // 平方
break;
case GMP_OP_SQRT:
mpz_sqrt(result, a); // 平方根(向下取整)
break;
case GMP_OP_EXP:
if (mpz_fits_ulong_p(b)) {
unsigned long exp = mpz_get_ui(b);
mpz_pow_ui(result, a, exp);
} else {
// 指数太大,设为0
mpz_set_ui(result, 0);
}
break;
case GMP_OP_MOD_EXP:
if (mpz_fits_ulong_p(b)) {
unsigned long exp = mpz_get_ui(b);
mpz_powm_ui(result, a, exp, c);
} else {
// 使用大指数版本
mpz_powm(result, a, b, c);
}
break;
case GMP_OP_GCD:
mpz_gcd(result, a, b);
break;
default:
mpz_set_ui(result, 0);
break;
}
// 将结果转换为字符串
res_str = mpz_get_str(NULL, 10, result);
if (res_str) {
set_string(&result_data[row], res_str);
free(res_str);
}
cleanup:
if (a_str) duckdb_free(a_str);
if (b_str) duckdb_free(b_str);
if (c_str) duckdb_free(c_str);
// 清理GMP整数
mpz_clear(a);
mpz_clear(result);
if (b_data) mpz_clear(b);
if (c_data) mpz_clear(c);
}
}
// 函数配置信息
typedef struct {
const char* name;
gmp_operation_t op_type;
int num_params;
} gmp_function_config;
static gmp_function_config gmp_functions[] = {
{"gmp_add", GMP_OP_ADD, 2},
{"gmp_sub", GMP_OP_SUB, 2},
{"gmp_mul", GMP_OP_MUL, 2},
{"gmp_div", GMP_OP_DIV, 2},
{"gmp_mod", GMP_OP_MOD, 2},
{"gmp_sqr", GMP_OP_SQR, 1},
{"gmp_sqrt", GMP_OP_SQRT, 1},
{"gmp_exp", GMP_OP_EXP, 2},
{"gmp_mod_exp", GMP_OP_MOD_EXP, 3},
{"gmp_gcd", GMP_OP_GCD, 2}
};
// 注册所有GMP函数
void RegisterGMPFunctions(duckdb_connection connection) {
duckdb_logical_type varchar_type = duckdb_create_logical_type(DUCKDB_TYPE_VARCHAR);
for (size_t i = 0; i < sizeof(gmp_functions) / sizeof(gmp_functions[0]); i++) {
gmp_function_config config = gmp_functions[i];
duckdb_scalar_function function = duckdb_create_scalar_function();
duckdb_scalar_function_set_name(function, config.name);
// 添加参数
for (int j = 0; j < config.num_params; j++) {
duckdb_scalar_function_add_parameter(function, varchar_type);
}
// 设置返回类型
duckdb_scalar_function_set_return_type(function, varchar_type);
// 设置操作类型作为extra_info
duckdb_scalar_function_set_extra_info(function, (void*)(uintptr_t)config.op_type, NULL);
// 设置函数实现
duckdb_scalar_function_set_function(function, GMPGenericFunction);
// 注册函数
duckdb_register_scalar_function(connection, function);
duckdb_destroy_scalar_function(&function);
}
duckdb_destroy_logical_type(&varchar_type);
}
将它保存为gmp.c然后放入bn.c的相同目录,并在capi_quack.c和add_numbers.h中添加RegisterGMPFunctions的调用和声明
编译执行如下,注意需要添加-lgmp选项动态链接gmp库。
root@6ae32a5ffcde:/par/cext# gcc -fPIC -shared -o libtest2.so *.c -I . -lssl -lcrypto -lgmp
root@6ae32a5ffcde:/par/cext# python3 ./appmeta.py -l libtest2.so -n add -dv v1.2.0 --duckdb-platform linux_amd64 --extension-version 0.1
Creating extension binary:
- Input file: libtest2.so
- Output file: add.duckdb_extension
root@6ae32a5ffcde:/par/cext# /par/duckdb140 -unsigned
DuckDB v1.4.0 (Andium) b8a06e4a22
Enter ".help" for usage hints.
D load '/par/cext/add.duckdb_extension';
D select gmp_add('10','999');
┌──────────────────────┐
│ gmp_add('10', '999') │
│ varchar │
├──────────────────────┤
│ 1009 │
└──────────────────────┘
D .timer on
D with recursive t as(select '1' a,'1' s union all select (a::int+1)::varchar x, bn_mul(s,x) from t where a::int <1000) from t where a='1000';
┌─────────┬────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
│ a │ s │
│ varchar │ varchar │
├─────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 1000 │ 40238726007709377354370243392300398571937486421071463254379991042993851239862902059204420848696940480047… │
└─────────┴────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
Run Time (s): real 0.332 user 0.232357 sys 0.257653
D with recursive t as(select '1' a,'1' s union all select (a::int+1)::varchar x, gmp_mul(s,x) from t where a::int <1000) from t where a='1000';
┌─────────┬────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
│ a │ s │
│ varchar │ varchar │
├─────────┼────────────────────────────────────────────────────────────────────────────────────────────────────────────┤
│ 1000 │ 40238726007709377354370243392300398571937486421071463254379991042993851239862902059204420848696940480047… │
└─────────┴────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
Run Time (s): real 0.194 user 0.180921 sys 0.148959
D select gmp_mod('256','7');
┌─────────────────────┐
│ gmp_mod('256', '7') │
│ varchar │
├─────────────────────┤
│ 4 │
└─────────────────────┘
Run Time (s): real 0.003 user 0.002576 sys 0.000000
测试计算1000的阶乘,确实gmp库更快, 计算mod也更方便,不用每次都调用BN_new(),大概这也是快的一个原因。

6114

被折叠的 条评论
为什么被折叠?



