利用DeepSeek用C语言实现DuckDB内置repeat函数插件

DuckDB内置repeat函数的C++源代码如下,因为用到了模板,非常简洁:

namespace duckdb {

struct RepeatFunctionData : public TableFunctionData {
	RepeatFunctionData(Value value, idx_t target_count) : value(std::move(value)), target_count(target_count) {
	}

	Value value;
	idx_t target_count;
};

// Global scan state of the repeat table function.
struct RepeatOperatorData : public GlobalTableFunctionState {
	RepeatOperatorData() = default;

	// Number of rows emitted so far; starts at zero.
	idx_t current_count = 0;
};

// Bind callback of repeat(value, count).
//
// Declares a single result column whose type is the type of the first argument
// and whose name is that argument's string form, validates the repeat count and
// returns the bind data (value + target row count).
//
// Throws BinderException when the count is NULL or negative.
static unique_ptr<FunctionData> RepeatBind(ClientContext &context, TableFunctionBindInput &input,
                                           vector<LogicalType> &return_types, vector<string> &names) {
	// the repeat function returns the type of the first argument
	auto &inputs = input.inputs;
	return_types.push_back(inputs[0].type());
	names.push_back(inputs[0].ToString());
	if (inputs[1].IsNull()) {
		throw BinderException("Repeat second parameter cannot be NULL");
	}
	auto repeat_count = inputs[1].GetValue<int64_t>();
	if (repeat_count < 0) {
		// Message fixed: the original read "cannot be be less than 0".
		throw BinderException("Repeat second parameter cannot be less than 0");
	}
	// The count is known to be non-negative here, so the cast to idx_t is in range.
	return make_uniq<RepeatFunctionData>(inputs[0], NumericCast<idx_t>(repeat_count));
}

// Init callback: creates the global state with its row counter at zero.
static unique_ptr<GlobalTableFunctionState> RepeatInit(ClientContext &context, TableFunctionInitInput &input) {
	auto fresh_state = make_uniq<RepeatOperatorData>();
	return std::move(fresh_state);
}

// Scan callback: emits the next batch of rows, all referencing the bound value.
static void RepeatFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
	auto &request = data_p.bind_data->Cast<RepeatFunctionData>();
	auto &progress = data_p.global_state->Cast<RepeatOperatorData>();

	// Rows still owed, capped at STANDARD_VECTOR_SIZE per call.
	const idx_t outstanding = request.target_count - progress.current_count;
	const idx_t batch = MinValue<idx_t>(outstanding, STANDARD_VECTOR_SIZE);

	// The output column references the bound value instead of copying it per row.
	output.data[0].Reference(request.value);
	output.SetCardinality(batch);
	progress.current_count += batch;
}

// Cardinality callback: passes the requested row count as both NodeStatistics arguments.
static unique_ptr<NodeStatistics> RepeatCardinality(ClientContext &context, const FunctionData *bind_data_p) {
	const idx_t total_rows = bind_data_p->Cast<RepeatFunctionData>().target_count;
	return make_uniq<NodeStatistics>(total_rows, total_rows);
}

// Registers repeat(ANY, BIGINT) together with its cardinality callback.
void RepeatTableFunction::RegisterFunction(BuiltinFunctions &set) {
	vector<LogicalType> arguments {LogicalType::ANY, LogicalType::BIGINT};
	TableFunction repeat_fun("repeat", arguments, RepeatFunction, RepeatBind, RepeatInit);
	repeat_fun.cardinality = RepeatCardinality;
	set.AddFunction(repeat_fun);
}

} // namespace duckdb

让DeepSeek按照上述函数实现的C版本就比较冗长。最初实现的一版,在bind函数中需要复制数据,复制之前要先读取类型,再用一个大型switch语句调用不同的函数;在Function函数中把数据重复n次时,又要经历同样的过程。每增加一种类型,都需要前后跳着修改。下面是按我的要求优化后的版本:把各种类型复制数据和重复n次数据的函数单独实现,然后用循环查找类型,并通过函数指针调用。虽然总长度差不多,但修改更集中一些。
他的实现保存duckdb_logical_type,要申请内存还要注意销毁,我改成duckdb_type,其实就是枚举的整数,避免了这些麻烦。
值得一提的是,他一开始在代码中引用了一个并不存在的duckdb_value_copy函数。我直接删除了这个函数调用,改为对value赋值,但这个value其实是指针,所指的局部对象被销毁后就取不到值了,为此浪费了很多时间排查。

#define DUCKDB_EXTENSION_API_VERSION_UNSTABLE 1
#include "duckdb_extension.h"
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
DUCKDB_EXTENSION_EXTERN

#define STANDARD_VECTOR_SIZE 2048

// Table function data structures
// Bind data: what to repeat and how many times.
typedef struct {
    // Type id of the bound value; used to look up the matching TypeHandler.
    duckdb_type type_id;
    // Buffer allocated in RepeatBind (handler->data_size bytes) and filled by the handler's extract callback.
    void* value_data;
    // Total number of rows to produce.
    idx_t target_count;
} RepeatFunctionData;

// Init (scan) state: number of rows already emitted.
typedef struct {
    idx_t current_count;
} RepeatOperatorData;

// Function pointer types for per-type handling
// ValueExtractor: reads the payload of a duckdb_value into the buffer given as second argument.
// ValueRepeater:  writes the buffered payload into the first `count` rows of a vector.
// ValueDestroyer: releases resources referenced from the buffer (not the buffer itself).
typedef void (*ValueExtractor)(duckdb_value, void*);
typedef void (*ValueRepeater)(void*, duckdb_vector, idx_t);
typedef void (*ValueDestroyer)(void*);

// Per-type handler description
typedef struct {
    // DuckDB type this entry handles.
    duckdb_type type_id;
    // Number of bytes RepeatBind allocates for the extracted value.
    size_t data_size;
    // Copies the value out of a duckdb_value into the buffer.
    ValueExtractor extract;
    // Fills an output vector with the buffered value.
    ValueRepeater repeat;
    // Optional cleanup for the buffer's contents; NULL when nothing needs releasing.
    ValueDestroyer destroy;
} TypeHandler;

// Boolean handling
// Stores the result of duckdb_get_bool(value) through `data`, treated as bool*.
static void ExtractBoolean(duckdb_value value, void* data) {
    bool* slot = (bool*)data;
    *slot = duckdb_get_bool(value);
}

// Writes the bool stored at `data` into the first `count` slots of the vector's data array.
static void RepeatBoolean(void* data, duckdb_vector vector, idx_t count) {
    const bool fill = *(const bool*)data;
    bool* cursor = (bool*)duckdb_vector_get_data(vector);
    for (idx_t left = count; left > 0; left--) {
        *cursor++ = fill;
    }
}

// Integer type handling
// Stores the result of duckdb_get_int8(value) through `data`, treated as int8_t*.
static void ExtractInt8(duckdb_value value, void* data) {
    int8_t* slot = (int8_t*)data;
    *slot = duckdb_get_int8(value);
}

// Writes the int8_t stored at `data` into the first `count` slots of the vector's data array.
static void RepeatInt8(void* data, duckdb_vector vector, idx_t count) {
    const int8_t fill = *(const int8_t*)data;
    int8_t* out = (int8_t*)duckdb_vector_get_data(vector);
    idx_t row = 0;
    while (row < count) {
        out[row] = fill;
        row++;
    }
}

// Stores the result of duckdb_get_int16(value) through `data`, treated as int16_t*.
static void ExtractInt16(duckdb_value value, void* data) {
    const int16_t extracted = duckdb_get_int16(value);
    *(int16_t*)data = extracted;
}

// Writes the int16_t stored at `data` into the first `count` slots of the vector's data array.
static void RepeatInt16(void* data, duckdb_vector vector, idx_t count) {
    const int16_t fill = *(const int16_t*)data;
    int16_t* cursor = (int16_t*)duckdb_vector_get_data(vector);
    for (idx_t left = count; left > 0; left--) {
        *cursor++ = fill;
    }
}

// Stores the result of duckdb_get_int32(value) through `data`, treated as int32_t*.
static void ExtractInt32(duckdb_value value, void* data) {
    int32_t* slot = (int32_t*)data;
    *slot = duckdb_get_int32(value);
}

// Writes the int32_t stored at `data` into the first `count` slots of the vector's data array.
static void RepeatInt32(void* data, duckdb_vector vector, idx_t count) {
    const int32_t fill = *(const int32_t*)data;
    int32_t* out = (int32_t*)duckdb_vector_get_data(vector);
    idx_t row = 0;
    while (row < count) {
        out[row] = fill;
        row++;
    }
}

// Stores the result of duckdb_get_int64(value) through `data`, treated as int64_t*.
static void ExtractInt64(duckdb_value value, void* data) {
    const int64_t extracted = duckdb_get_int64(value);
    *(int64_t*)data = extracted;
}

// Writes the int64_t stored at `data` into the first `count` slots of the vector's data array.
static void RepeatInt64(void* data, duckdb_vector vector, idx_t count) {
    const int64_t fill = *(const int64_t*)data;
    int64_t* cursor = (int64_t*)duckdb_vector_get_data(vector);
    for (idx_t left = count; left > 0; left--) {
        *cursor++ = fill;
    }
}

// Floating-point handling
// Stores the result of duckdb_get_float(value) through `data`, treated as float*.
static void ExtractFloat(duckdb_value value, void* data) {
    float* slot = (float*)data;
    *slot = duckdb_get_float(value);
}

// Writes the float stored at `data` into the first `count` slots of the vector's data array.
static void RepeatFloat(void* data, duckdb_vector vector, idx_t count) {
    const float fill = *(const float*)data;
    float* out = (float*)duckdb_vector_get_data(vector);
    idx_t row = 0;
    while (row < count) {
        out[row] = fill;
        row++;
    }
}

// Stores the result of duckdb_get_double(value) through `data`, treated as double*.
static void ExtractDouble(duckdb_value value, void* data) {
    const double extracted = duckdb_get_double(value);
    *(double*)data = extracted;
}

// Writes the double stored at `data` into the first `count` slots of the vector's data array.
static void RepeatDouble(void* data, duckdb_vector vector, idx_t count) {
    const double fill = *(const double*)data;
    double* cursor = (double*)duckdb_vector_get_data(vector);
    for (idx_t left = count; left > 0; left--) {
        *cursor++ = fill;
    }
}

// String handling
// Stores an owned, NUL-terminated copy of the value's string form in `data`
// (treated as char**). The stored pointer must later be released with
// duckdb_free, which is what DestroyVarchar does.
static void ExtractVarchar(duckdb_value value, void* data) {
    char** str_ptr = (char**)data;
    // duckdb_get_varchar already returns a fresh allocation owned by the caller.
    // Keep it as-is: duplicating it (as the previous code did) leaked the
    // original buffer and dereferenced it without checking for NULL.
    *str_ptr = duckdb_get_varchar(value);
}

// Writes the string referenced from `data` (a char**) into the first `count`
// rows of `vector`.
//
// The rows are assigned through duckdb_vector_assign_string_element_len, which
// copies the bytes into storage owned by the vector. The previous hand-built
// duckdb_string_t allocated one buffer per long row that nothing freed, and it
// left the inlined padding bytes and the pointer prefix unset.
static void RepeatVarchar(void* data, duckdb_vector vector, idx_t count) {
    const char* value = *(char**)data;
    if (!value) {
        // Defensive: extraction may have failed to allocate; emit empty strings
        // rather than dereferencing NULL.
        value = "";
    }
    // Computed once: the same string is written to every row.
    idx_t len = (idx_t)strlen(value);

    for (idx_t i = 0; i < count; i++) {
        duckdb_vector_assign_string_element_len(vector, i, value, len);
    }
}

// Releases the string referenced from `data` (a char**) with duckdb_free;
// a NULL string pointer is left alone.
static void DestroyVarchar(void* data) {
    char** slot = (char**)data;
    if (*slot != NULL) {
        duckdb_free(*slot);
    }
}

// Date and time handling
// Stores the result of duckdb_get_date(value) through `data`, treated as duckdb_date*.
static void ExtractDate(duckdb_value value, void* data) {
    duckdb_date* slot = (duckdb_date*)data;
    *slot = duckdb_get_date(value);
}

// Writes the duckdb_date stored at `data` into the first `count` slots of the vector's data array.
static void RepeatDate(void* data, duckdb_vector vector, idx_t count) {
    const duckdb_date fill = *(const duckdb_date*)data;
    duckdb_date* out = (duckdb_date*)duckdb_vector_get_data(vector);
    idx_t row = 0;
    while (row < count) {
        out[row] = fill;
        row++;
    }
}

// Stores the result of duckdb_get_time(value) through `data`, treated as duckdb_time*.
static void ExtractTime(duckdb_value value, void* data) {
    const duckdb_time extracted = duckdb_get_time(value);
    *(duckdb_time*)data = extracted;
}

// Writes the duckdb_time stored at `data` into the first `count` slots of the vector's data array.
static void RepeatTime(void* data, duckdb_vector vector, idx_t count) {
    const duckdb_time fill = *(const duckdb_time*)data;
    duckdb_time* cursor = (duckdb_time*)duckdb_vector_get_data(vector);
    for (idx_t left = count; left > 0; left--) {
        *cursor++ = fill;
    }
}

// Stores the result of duckdb_get_timestamp(value) through `data`, treated as duckdb_timestamp*.
static void ExtractTimestamp(duckdb_value value, void* data) {
    duckdb_timestamp* slot = (duckdb_timestamp*)data;
    *slot = duckdb_get_timestamp(value);
}

// Writes the duckdb_timestamp stored at `data` into the first `count` slots of the vector's data array.
static void RepeatTimestamp(void* data, duckdb_vector vector, idx_t count) {
    const duckdb_timestamp fill = *(const duckdb_timestamp*)data;
    duckdb_timestamp* out = (duckdb_timestamp*)duckdb_vector_get_data(vector);
    idx_t row = 0;
    while (row < count) {
        out[row] = fill;
        row++;
    }
}

// Type handler table
// One row per supported type: {type id, buffer size, extract, repeat, destroy}.
// Types without an entry here are rejected in RepeatBind with an "Unsupported
// data type" error. Only VARCHAR has a destroy callback, because its buffer
// holds a pointer to a separately allocated string.
static TypeHandler type_handlers[] = {
    {DUCKDB_TYPE_BOOLEAN, sizeof(bool), ExtractBoolean, RepeatBoolean, NULL},
    {DUCKDB_TYPE_TINYINT, sizeof(int8_t), ExtractInt8, RepeatInt8, NULL},
    {DUCKDB_TYPE_SMALLINT, sizeof(int16_t), ExtractInt16, RepeatInt16, NULL},
    {DUCKDB_TYPE_INTEGER, sizeof(int32_t), ExtractInt32, RepeatInt32, NULL},
    {DUCKDB_TYPE_BIGINT, sizeof(int64_t), ExtractInt64, RepeatInt64, NULL},
    {DUCKDB_TYPE_FLOAT, sizeof(float), ExtractFloat, RepeatFloat, NULL},
    {DUCKDB_TYPE_DOUBLE, sizeof(double), ExtractDouble, RepeatDouble, NULL},
    {DUCKDB_TYPE_VARCHAR, sizeof(char*), ExtractVarchar, RepeatVarchar, DestroyVarchar},
    {DUCKDB_TYPE_DATE, sizeof(duckdb_date), ExtractDate, RepeatDate, NULL},
    {DUCKDB_TYPE_TIME, sizeof(duckdb_time), ExtractTime, RepeatTime, NULL},
    {DUCKDB_TYPE_TIMESTAMP, sizeof(duckdb_timestamp), ExtractTimestamp, RepeatTimestamp, NULL}
};

// Look up a type handler
// Scans type_handlers front to back and returns the first entry whose type_id
// matches, or NULL when the type is not in the table.
static TypeHandler* FindTypeHandler(duckdb_type type_id) {
    TypeHandler* const table_end = type_handlers + sizeof(type_handlers) / sizeof(type_handlers[0]);
    for (TypeHandler* entry = type_handlers; entry != table_end; ++entry) {
        if (entry->type_id == type_id) {
            return entry;
        }
    }
    return NULL;
}

// Bind函数
static void RepeatBind(duckdb_bind_info info) {
    idx_t param_count = duckdb_bind_get_parameter_count(info);
    if (param_count != 2) {
        duckdb_bind_set_error(info, "Repeat function requires exactly 2 parameters");
        return;
    }
    
    duckdb_value value_param = duckdb_bind_get_parameter(info, 0);
    if (duckdb_is_null_value(value_param)) {
        duckdb_bind_set_error(info, "Repeat first parameter cannot be NULL");
        return;
    }

    duckdb_value count_param = duckdb_bind_get_parameter(info, 1);
    if (duckdb_is_null_value(count_param)) {
        duckdb_bind_set_error(info, "Repeat second parameter cannot be NULL");
        return;
    }
    
    int64_t repeat_count = duckdb_get_int64(count_param);
    if (repeat_count < 0) {
        duckdb_bind_set_error(info, "Repeat second parameter cannot be less than 0");
        return;
    }
    
    duckdb_logical_type value_type = duckdb_get_value_type(value_param);
    duckdb_type type_id = duckdb_get_type_id(value_type);

    TypeHandler* handler = FindTypeHandler(type_id);
    if (!handler) {
        duckdb_bind_set_error(info, "Unsupported data type for repeat bind function");
        duckdb_destroy_logical_type(&value_type);
        return;
    }
    
    duckdb_bind_add_result_column(info, "result", value_type);
    
    RepeatFunctionData* bind_data = duckdb_malloc(sizeof(RepeatFunctionData));
    bind_data->type_id = type_id;  
    bind_data->value_data = duckdb_malloc(handler->data_size);
    bind_data->target_count = (idx_t)repeat_count;
    
    handler->extract(value_param, bind_data->value_data);
    
    duckdb_bind_set_bind_data(info, bind_data, duckdb_free);
    duckdb_bind_set_cardinality(info, bind_data->target_count, true);
    duckdb_destroy_logical_type(&value_type);
}

// Init function
// Allocates the scan state with its row counter at zero and hands it to DuckDB,
// which releases it with duckdb_free. Reports an error on `info` if the
// allocation fails.
static void RepeatInit(duckdb_init_info info) {
    RepeatOperatorData* init_data = (RepeatOperatorData*)duckdb_malloc(sizeof(RepeatOperatorData));
    if (!init_data) {
        duckdb_init_set_error(info, "Out of memory while initializing repeat function");
        return;
    }
    init_data->current_count = 0;
    duckdb_init_set_init_data(info, init_data, duckdb_free);
}

// Table function implementation
// Produces the next batch of rows: fills column 0 of `output` with the bound
// value and advances the emitted-row counter. An output size of 0 signals that
// all rows have been produced.
static void RepeatFunction(duckdb_function_info info, duckdb_data_chunk output) {
    RepeatFunctionData* bind_data = (RepeatFunctionData*)duckdb_function_get_bind_data(info);
    RepeatOperatorData* init_data = (RepeatOperatorData*)duckdb_function_get_init_data(info);

    if (!bind_data || !init_data) {
        duckdb_function_set_error(info, "Invalid bind or init data");
        return;
    }

    idx_t remaining = bind_data->target_count - init_data->current_count;
    if (remaining == 0) {
        duckdb_data_chunk_set_size(output, 0);
        return;
    }

    // Ask the running DuckDB for its vector size instead of assuming 2048, so
    // the batch never exceeds what the output chunk can hold.
    idx_t chunk_capacity = duckdb_vector_size();
    idx_t output_size = remaining < chunk_capacity ? remaining : chunk_capacity;
    duckdb_vector output_vector = duckdb_data_chunk_get_vector(output, 0);

    TypeHandler* handler = FindTypeHandler(bind_data->type_id);
    if (!handler) {
        duckdb_function_set_error(info, "Unsupported data type for repeat function");
        return;
    }

    handler->repeat(bind_data->value_data, output_vector, output_size);

    // Explicitly mark every produced row as valid (non-NULL).
    duckdb_vector_ensure_validity_writable(output_vector);
    uint64_t* validity = duckdb_vector_get_validity(output_vector);
    for (idx_t i = 0; i < output_size; i++) {
        duckdb_validity_set_row_valid(validity, i);
    }

    duckdb_data_chunk_set_size(output, output_size);
    init_data->current_count += output_size;
}

// Destroy function
// Releases a RepeatFunctionData: first whatever the value buffer references
// (via the type's destroy callback, if it has one), then the buffer, then the
// struct itself.
static void RepeatBindDataDestroy(void* data) {
    RepeatFunctionData* owned = (RepeatFunctionData*)data;
    void* payload = owned->value_data;

    if (payload != NULL) {
        TypeHandler* handler = FindTypeHandler(owned->type_id);
        if (handler != NULL && handler->destroy != NULL) {
            handler->destroy(payload);
        }
        duckdb_free(payload);
    }

    duckdb_free(owned);
}

// Register the repeat table function
// Registers the table function "myrepeat"(ANY, BIGINT) on `connection` with the
// bind/init/function callbacks defined in this file. A registration failure is
// reported on stderr (the function has no return value to carry it).
void RegisterRepeatTableFunction(duckdb_connection connection) {
    duckdb_table_function table_func = duckdb_create_table_function();
    duckdb_table_function_set_name(table_func, "myrepeat");

    duckdb_logical_type any_type = duckdb_create_logical_type(DUCKDB_TYPE_ANY);
    duckdb_logical_type bigint_type = duckdb_create_logical_type(DUCKDB_TYPE_BIGINT);

    duckdb_table_function_add_parameter(table_func, any_type);
    duckdb_table_function_add_parameter(table_func, bigint_type);

    duckdb_table_function_set_bind(table_func, RepeatBind);
    duckdb_table_function_set_init(table_func, RepeatInit);
    duckdb_table_function_set_function(table_func, RepeatFunction);

    // Previously the result was ignored, so a failed registration went unnoticed
    // until the first query reported an unknown function.
    if (duckdb_register_table_function(connection, table_func) == DuckDBError) {
        fprintf(stderr, "Failed to register table function \"myrepeat\"\n");
    }

    // The handles are released here on both the success and the failure path.
    duckdb_destroy_table_function(&table_func);
    duckdb_destroy_logical_type(&any_type);
    duckdb_destroy_logical_type(&bigint_type);
}

将上述代码保存为myrepeat.c,将RegisterRepeatTableFunction的声明和调用加入capi_quack.c,
以下是编译插件和调用函数的过程。注意:未经强制类型转换的小数字面量是Decimal类型,我没有实现;列表等复合类型也没有实现。

/par/cext# gcc -fPIC -shared -o libtest2.so *.c -I . -lssl -lcrypto -lgmp
/par/cext# python3 ./appmeta.py -l libtest2.so -n add -dv v1.2.0  --duckdb-platform linux_amd64 --extension-version 0.1 >/dev/null
/par/cext# /par/duckdb140 -unsigned
DuckDB v1.4.0 (Andium) b8a06e4a22
Enter ".help" for usage hints.
D load '/par/cext/add.duckdb_extension';
D select count(*) from myrepeat(2::varchar,3000);
┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│     3000     │
└──────────────┘
D select count(*) from myrepeat(1.2,3000);
Binder Error:
Unsupported data type for repeat bind function
D select count(*) from myrepeat(1.2::double,3000);
┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│     3000     │
└──────────────┘
D select count(*) from myrepeat(now(),3000);
Binder Error:
Unsupported data type for repeat bind function
D select count(*) from myrepeat(now()::time,3000);
Conversion Error:
Unimplemented type for cast (TIMESTAMP WITH TIME ZONE -> TIME)

LINE 1: select count(*) from myrepeat(now()::time,3000);
                                           ^
D select count(*) from myrepeat(now()::timestamp,3000);
┌──────────────┐
│ count_star() │
│    int64     │
├──────────────┤
│     3000     │
└──────────────┘
D select * from myrepeat(123.456::float,3);
┌─────────┐
│ result  │
│  float  │
├─────────┤
│ 123.456 │
│ 123.456 │
│ 123.456 │
└─────────┘
D select * from myrepeat('abc',3);
┌─────────┐
│ result  │
│ varchar │
├─────────┤
│ abc     │
│ abc     │
│ abc     │
└─────────┘
D select * from myrepeat(now()::date,3);
┌────────────┐
│   result   │
│    date    │
├────────────┤
│ 2025-09-25 │
│ 2025-09-25 │
│ 2025-09-25 │
└────────────┘

程序很容易修改成支持多种类型的range()等表函数。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值