利用DeepSeek编写的Datafusion自定义表函数read_csv和read_xls

很早就用过Datafusion的python包,也是一个很优秀的数据库SQL查询引擎,最近用docker安装了rust语言环境,就想着能否像duckdb一样用自定义函数扩充Datafusion的功能,幸运的是Datafusion提供了详细的添加自定义函数指南,把这份材料馈入DeepSeek,经过几轮对话修改,成功地实现了支持多种字符集的read_csv和read_xls(实际是xlsx格式)功能。

第一段提示词:

请参考附件,实现datafusion的read_xls() read_csv(fn,encoding) read_pg(url,query)表函数,mpz_add/mpz_mul标量函数/mpz_sum聚合函数,请用中文注释

他给出了除Cargo.toml以外所有的代码,包括注册函数和测试功能的代码,但我需要编译,又请它给出了Cargo.toml。
最初的Cargo.toml

[package]
name = "dfdf"
version = "0.1.0"
edition = "2021"

[dependencies]
datafusion = "24.0"  # 使用最新稳定版
datafusion-expr = "24.0"
datafusion-common = "24.0"
arrow = "40.0"
tokio = { version = "1.0", features = ["full"] }
async-trait = "0.1"
num-bigint = "0.4"
calamine = "0.22"    # Excel 文件支持
encoding_rs = "0.8"  # 编码转换
postgres = { version = "0.19", features = ["with-uuid-1"] }  # PostgreSQL 支持
csv = "1.2"

这里有个陷阱,他是根据已有的知识给出的版本号,而当我把datafusion改成实际最新版本48.0以后,相应的依赖库版本也要修改,否则就会发生版本冲突编译错误。
实际上编译器一开始就提示了。

cargo build
    Updating `mirror` index
     Locking 350 packages to latest Rust 1.87.0 compatible versions
      Adding arrow v40.0.0 (available: v55.1.0)
      Adding async-compression v0.4.19 (available: v0.4.25)
      Adding calamine v0.22.1 (available: v0.28.0)

  Compiling arrow-arith v40.0.0
error[E0034]: multiple applicable items in scope
   --> /usr/local/cargo/registry/src/mirrors.tuna.tsinghua.edu.cn-4dc01642fd091eda/arrow-arith-40.0.0/src/temporal.rs:262:47
    |
262 |     time_fraction_dyn(array, "quarter", |t| t.quarter() as i32)
    |                                               ^^^^^^^ multiple `quarter` found
    |

用如下命令可以看到版本依赖关系

cargo tree -d

arrow v40.0.0
└── dfdf v0.1.0 (/par/dfdf)

arrow v55.1.0
├── datafusion v48.0.0
│   └── dfdf v0.1.0 (/par/dfdf)
├── datafusion-catalog v48.0.0
│   ├── datafusion v48.0.0 (*)

将版本改对以后,相应的错误全都消失了。

新建项目和编译命令如下,其中–vcs none选项表示不把这个目录初始化为一个新的 git 存储库,

cargo new dfdf --bin --vcs none

cd dfdf
cargo build

最终的代码如下:
Cargo.toml

[package]
name = "dfdf"
version = "0.1.0"
edition = "2024"

[dependencies]
datafusion = "48.0"  # 使用最新稳定版
datafusion-expr = "48.0"
datafusion-common = "48.0"
arrow = "55.1"
tokio = { version = "1.0", features = ["full"] }
async-trait = "0.1"
num-bigint = "0.4"
calamine = "0.28"    # Excel 文件支持
encoding_rs = "0.8"  # 编码转换
postgres = { version = "0.19", features = ["with-uuid-1"] }  # PostgreSQL 支持
csv = "1.2"

read_csv.rs

use std::sync::Arc;
use std::fs::File;
use std::io::Read;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::array::{ArrayRef, StringArray};
use datafusion::common::{plan_err, DataFusionError, Result};
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::datasource::memory::MemTable;
use datafusion_expr::Expr;
use datafusion::common::ScalarValue;
use encoding_rs::{Encoding, UTF_8};

use arrow::record_batch::RecordBatch;


/// 读取CSV文件的表函数(支持指定编码)
#[derive(Debug)]
pub struct ReadCsvFunction;

impl TableFunctionImpl for ReadCsvFunction {
    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        // 检查参数数量
        if exprs.len() < 1 || exprs.len() > 2 {
            return plan_err!("read_csv函数需要1-2个参数: 文件路径和可选编码");
        }

        // 解析文件路径
        let file_path = match &exprs[0] {
            Expr::Literal(ScalarValue::Utf8(Some(path)), _) => path,
            _ => return plan_err!("第一个参数必须是字符串类型的文件路径"),
        };

        // 解析编码(默认为UTF-8)
        let encoding = if exprs.len() > 1 {
            match &exprs[1] {
                Expr::Literal(ScalarValue::Utf8(Some(enc)), _) => {
                    Encoding::for_label(enc.as_bytes()).unwrap_or(UTF_8)
                }
                _ => return plan_err!("第二个参数(编码)必须是字符串"),
            }
        } else {
            UTF_8
        };

        // 读取文件内容
        let mut file = File::open(file_path)
            .map_err(|e| DataFusionError::Execution(format!("打开CSV文件失败: {}", e)))?;
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes)
            .map_err(|e| DataFusionError::Execution(format!("读取CSV文件失败: {}", e)))?;

        // 转换编码
        let (text, _, _) = encoding.decode(&bytes);
        
        // 使用DataFusion内置CSV解析器(简化版)
        let mut reader = csv::Reader::from_reader(text.as_bytes());
        let headers = reader.headers()
            .map_err(|e| DataFusionError::Execution(format!("读取CSV头失败: {}", e)))?;
        
        // 创建Schema
        let fields = headers.iter()
            .map(|h| Field::new(h, DataType::Utf8, true))
            .collect::<Vec<_>>();
        let schema = Arc::new(Schema::new(fields.clone()));

        // 读取数据
        let mut columns = vec![vec![]; fields.len()];
        for record in reader.records() {
            let record = record.map_err(|e| DataFusionError::Execution(format!("解析CSV记录失败: {}", e)))?;
            for (i, field) in record.iter().enumerate() {
                columns[i].push(field.to_string());
            }
        }

        // 转换为RecordBatch
        let columns = columns.into_iter()
            .map(|col| Arc::new(StringArray::from(col)) as ArrayRef)
            .collect::<Vec<_>>();
        let batch = RecordBatch::try_new(schema.clone(), columns)
            .map_err(|e| DataFusionError::Execution(format!("创建RecordBatch失败: {}", e)))?;
        
        // 创建内存表
        let provider = MemTable::try_new(schema, vec![vec![batch]])
            .map_err(|e| DataFusionError::Execution(format!("创建内存表失败: {}", e)))?;
        Ok(Arc::new(provider))
    }
}

impl Default for ReadCsvFunction {
    fn default() -> Self {
        Self
    }
}

read_xls.rs

use std::sync::Arc;
use std::path::Path;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow::array::{ArrayRef, StringArray};
use datafusion::common::{DataFusionError, Result};
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::datasource::memory::MemTable;
use datafusion_expr::Expr;
use datafusion::common::ScalarValue;
use calamine::{open_workbook, Reader, Xlsx};

/// 读取Excel文件的表函数
#[derive(Debug)]
pub struct ReadXlsFunction;

impl TableFunctionImpl for ReadXlsFunction {
    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        // 检查参数数量
        if exprs.len() != 1 {
            return Err(DataFusionError::Execution(
                "read_xls函数需要1个参数: 文件路径".to_string(),
            ));
        }

        // 解析文件路径参数
        let file_path = match &exprs[0] {
            Expr::Literal(ScalarValue::Utf8(Some(path)), _) => path,
            _ => {
                return Err(DataFusionError::Execution(
                    "第一个参数必须是字符串类型的文件路径".to_string(),
                ))
            }
        };

        // 打开Excel文件
        let mut workbook: Xlsx<_> = open_workbook(Path::new(file_path)).map_err(|e| {
            DataFusionError::Execution(format!("打开Excel文件失败: {}", e))
        })?;

        // 获取第一个工作表
        let range = workbook.worksheet_range_at(0)
            .ok_or_else(|| DataFusionError::Execution("Excel文件中没有工作表".to_string()))?
            .map_err(|e| DataFusionError::Execution(format!("读取工作表失败: {}", e)))?;

        // 创建Schema(基于第一行)
        let first_row = range.rows().next().ok_or_else(|| {
            DataFusionError::Execution("Excel工作表为空".to_string())
        })?;

        let fields = first_row
            .iter()
            .enumerate()
            .map(|(i, _)| Field::new(format!("col_{}", i), DataType::Utf8, true))
            .collect::<Vec<_>>();

        let schema = Arc::new(Schema::new(fields.clone()));

        // 转换数据为RecordBatch(跳过第一行作为标题)
        let mut columns = vec![vec![]; fields.len()];
        for row in range.rows().skip(1) {
            for (i, cell) in row.iter().enumerate() {
                if i < columns.len() {
                    columns[i].push(cell.to_string());
                }
            }
        }

        // 创建StringArray列
        let arrow_columns = columns
            .into_iter()
            .map(|col| Arc::new(StringArray::from(col)) as ArrayRef)
            .collect::<Vec<_>>();

        // 创建RecordBatch
        let batch = RecordBatch::try_new(schema.clone(), arrow_columns).map_err(|e| {
            DataFusionError::Execution(format!("创建RecordBatch失败: {}", e))
        })?;

        // 创建内存表
        let provider = MemTable::try_new(schema, vec![vec![batch]]).map_err(|e| {
            DataFusionError::Execution(format!("创建内存表失败: {}", e))
        })?;

        Ok(Arc::new(provider))
    }
}

impl Default for ReadXlsFunction {
    fn default() -> Self {
        Self
    }
}

main.rs

mod read_csv;
mod read_xls;

use std::sync::Arc;
use datafusion::execution::context::SessionContext;
use datafusion::error::Result;

use read_csv::ReadCsvFunction;
use read_xls::ReadXlsFunction;


/// 注册所有自定义函数到DataFusion上下文
async fn register_custom_functions(ctx: &SessionContext) -> Result<()> {
    // 注册表函数
    ctx.register_udtf("read_csv", Arc::new(ReadCsvFunction::default()));
    // 注册函数
    ctx.register_udtf("read_xls", Arc::new(ReadXlsFunction::default()));
    
    Ok(())
}

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    
    // 注册自定义函数
    register_custom_functions(&ctx).await?;
    
    println!("所有自定义函数已注册成功!");
    
    // 测试代码
    let df = ctx.sql("SELECT * FROM read_csv('/par/foods.csv') limit 3").await?;
    df.show().await?;

    // 测试代码
    let df1 = ctx.sql("SELECT * FROM read_csv('/par/gbk_file.csv','GBK') limit 3").await?;
    df1.show().await?;


    // 测试代码
    let df2 = ctx.sql("SELECT * FROM read_csv('/par/big5_file.csv','BIG5') limit 3").await?;
    df2.show().await?;

    // 使用示例
    let df3 = ctx.sql("SELECT * FROM read_xls('/par/foods.xlsx')").await?;
    df3.show().await?;
    
    Ok(())
}

执行结果如下

target/debug/dfdf
所有自定义函数已注册成功!
+------------+----------+--------+----------+
| category   | calories | fats_g | sugars_g |
+------------+----------+--------+----------+
| vegetables | 45       | 0.5    | 2        |
| seafood    | 150      | 5      | 0        |
| meat       | 100      | 5      | 0        |
+------------+----------+--------+----------+
+----+---------------------------+-------+
| id |  lang                     |  code |
+----+---------------------------+-------+
| 1  |  中文                     | GBK   |
| 2  |  〇镕镚閫閬               |  GBK  |
| 3  |  烎玊奣嘦勥巭嫑恏兲氼忈炛 |  GBK  |
+----+---------------------------+-------+
+----+--------+-------+
| id |  lang  |  code |
+----+--------+-------+
| 1  |  中文  | GBK   |
| 2  |  閫閬  |  GBK  |
| 3  |  烎玊  |  GBK  |
+----+--------+-------+
+------------+-------+-------+-------+
| col_0      | col_1 | col_2 | col_3 |
+------------+-------+-------+-------+
| vegetables | 45    | 0.5   | 2     |
| 海鲜       | 150   | 5     | 0     |
|| 100   | 5     | 0     |
| fruit      | 60    | 0     | 11    |
+------------+-------+-------+-------+

总结:DeepSeek极大地降低了学习一门新程序语言的门槛,只要有任意一种语言基础,和阅读简单英文的能力,配合良好的文档,任何人都能做出有用的软件。

### 关于 Fusion_dataset 函数 对于 `Fusion_dataset` 函数的具体描述,在现有参考资料中并未直接提及此函数的相关细节。然而,基于 TensorFlow 数据处理的一般实践,可以推测该函数可能用于融合多个数据集或将不同来源的数据组合在一起形成一个新的数据集。 #### 假设性的功能概述 假设 `Fusion_dataset` 是一个自定义或特定框架内的函数,其主要作用可能是: - **合并操作**:接受两个或更多个输入数据集作为参数,并返回一个新的综合数据集。 - **兼容性调整**:确保来自不同源的数据能够被适当地转换标准化以便统一使用。 - **增强灵活性**:提供选项来指定如何处理重复项、缺失值以及是否保留原始索引等特性。 #### 示例代码展示 下面给出一段 Python 伪代码,模拟了上述提到的功能: ```python def fusion_dataset(*datasets, remove_duplicates=True, handle_missing='drop', preserve_index=False): """ 将多个 pandas DataFrame 或其他类型的 dataset 合并为单个数据集 参数: *datasets: 可变数量的 datasets (pandas.DataFrame 类型或其他支持的对象) remove_duplicates(bool): 是否移除重复记录,默认 True handle_missing(str): 处理缺失值的方式 ('drop'/'fill') ,默认 'drop' preserve_index(bool): 是否保持原有索引,默认 False 返回: fused_df(pandas.DataFrame): 融合后的数据框 """ import pandas as pd # 初始化空列收集所有传入的数据集 all_data = [] for ds in datasets: if isinstance(ds, pd.DataFrame): all_data.append(ds) # 进行实际的拼接工作 fused_df = pd.concat(all_data, axis=0, ignore_index=(not preserve_index)) # 移除重复条目(如果需要) if remove_duplicates: fused_df.drop_duplicates(inplace=True) # 对缺失值进行相应处理 if handle_missing == 'drop': fused_df.dropna(inplace=True) elif handle_missing == 'fill': fused_df.fillna(method="ffill", inplace=True) return fused_df ``` 这段代码展示了如何通过给定的一些配置选项灵活地将不同的数据集结合起来。请注意这只是一个概念验证性质的例子,并不代任何真实存在的 API 设计[^1]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值