很早以前我就用过DataFusion的Python包,它是一个很优秀的数据库SQL查询引擎。最近我用Docker安装了Rust语言环境,就想着能否像DuckDB一样,用自定义函数扩充DataFusion的功能。幸运的是,DataFusion提供了详细的添加自定义函数指南。我把这份材料提供给DeepSeek,经过几轮对话修改,成功地实现了支持多种字符集的read_csv和read_xls(实际是xlsx格式)功能。
第一段提示词:
请参考附件,实现datafusion的read_xls() read_csv(fn,encoding) read_pg(url,query)表函数,mpz_add/mpz_mul标量函数/mpz_sum聚合函数,请用中文注释
它给出了除Cargo.toml以外的所有代码,包括注册函数和测试功能的代码。但我需要编译,于是又请它给出了Cargo.toml。
最初的Cargo.toml
[package]
name = "dfdf"
version = "0.1.0"
edition = "2021"
[dependencies]
datafusion = "24.0" # 使用最新稳定版
datafusion-expr = "24.0"
datafusion-common = "24.0"
arrow = "40.0"
tokio = { version = "1.0", features = ["full"] }
async-trait = "0.1"
num-bigint = "0.4"
calamine = "0.22" # Excel 文件支持
encoding_rs = "0.8" # 编码转换
postgres = { version = "0.19", features = ["with-uuid-1"] } # PostgreSQL 支持
csv = "1.2"
这里有个陷阱:它是根据已有的知识给出的版本号,而当我把datafusion改成实际的最新版本48.0以后,相应依赖库的版本也要一并修改,否则就会发生版本冲突,导致编译错误。
实际上编译器一开始就提示了。
cargo build
Updating `mirror` index
Locking 350 packages to latest Rust 1.87.0 compatible versions
Adding arrow v40.0.0 (available: v55.1.0)
Adding async-compression v0.4.19 (available: v0.4.25)
Adding calamine v0.22.1 (available: v0.28.0)
Compiling arrow-arith v40.0.0
error[E0034]: multiple applicable items in scope
--> /usr/local/cargo/registry/src/mirrors.tuna.tsinghua.edu.cn-4dc01642fd091eda/arrow-arith-40.0.0/src/temporal.rs:262:47
|
262 | time_fraction_dyn(array, "quarter", |t| t.quarter() as i32)
| ^^^^^^^ multiple `quarter` found
|
用如下命令可以看到版本依赖关系
cargo tree -d
arrow v40.0.0
└── dfdf v0.1.0 (/par/dfdf)
arrow v55.1.0
├── datafusion v48.0.0
│ └── dfdf v0.1.0 (/par/dfdf)
├── datafusion-catalog v48.0.0
│ ├── datafusion v48.0.0 (*)
将版本改对以后,相应的错误全都消失了。
新建项目和编译的命令如下,其中 --vcs none 选项表示不把这个目录初始化为一个新的 git 存储库:
cargo new dfdf --bin --vcs none
cd dfdf
cargo build
最终的代码如下:
Cargo.toml
[package]
name = "dfdf"
version = "0.1.0"
edition = "2024"
[dependencies]
datafusion = "48.0" # use the latest stable release
datafusion-expr = "48.0"
datafusion-common = "48.0"
arrow = "55.1"
tokio = { version = "1.0", features = ["full"] }
async-trait = "0.1"
num-bigint = "0.4" # NOTE(review): not referenced by the code shown here; presumably for the mpz_* functions — confirm
calamine = "0.28" # Excel file support
encoding_rs = "0.8" # character-encoding conversion
postgres = { version = "0.19", features = ["with-uuid-1"] } # PostgreSQL support; NOTE(review): not referenced by the code shown here — confirm
csv = "1.2"
read_csv.rs
use std::sync::Arc;
use std::fs::File;
use std::io::Read;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::array::{ArrayRef, StringArray};
use datafusion::common::{plan_err, DataFusionError, Result};
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::datasource::memory::MemTable;
use datafusion_expr::Expr;
use datafusion::common::ScalarValue;
use encoding_rs::{Encoding, UTF_8};
use arrow::record_batch::RecordBatch;
/// Table function that reads a CSV file, with an optional source encoding.
#[derive(Debug)]
pub struct ReadCsvFunction;
impl TableFunctionImpl for ReadCsvFunction {
    /// Loads a CSV file into an in-memory table, decoding it with the requested encoding.
    ///
    /// # Arguments
    /// * `exprs[0]` - string literal holding the path of the CSV file.
    /// * `exprs[1]` - optional string literal holding an encoding label (for example
    ///   `'GBK'`); UTF-8 is used when it is omitted.
    ///
    /// Column names are taken from the CSV header row and every column is exposed as a
    /// nullable `Utf8` column.
    ///
    /// # Errors
    /// Returns a planning error for a wrong argument count, a non-string argument, or an
    /// encoding label that `encoding_rs` does not recognise. Returns an execution error
    /// when the file cannot be opened or read, the header row is missing, a record cannot
    /// be parsed, or the record batch / memory table cannot be built.
    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        // Validate the argument count.
        if exprs.is_empty() || exprs.len() > 2 {
            return plan_err!("read_csv函数需要1-2个参数: 文件路径和可选编码");
        }

        // First argument: the file path.
        let file_path = match &exprs[0] {
            Expr::Literal(ScalarValue::Utf8(Some(path)), _) => path,
            _ => return plan_err!("第一个参数必须是字符串类型的文件路径"),
        };

        // Second argument: the encoding label (UTF-8 when absent).
        // An unknown label is rejected instead of silently decoding as UTF-8, which
        // would turn a typo in the label into garbled text with no diagnostic.
        let encoding = match exprs.get(1) {
            None => UTF_8,
            Some(Expr::Literal(ScalarValue::Utf8(Some(label)), _)) => {
                match Encoding::for_label(label.as_bytes()) {
                    Some(found) => found,
                    None => return plan_err!("不支持的编码: {}", label),
                }
            }
            Some(_) => return plan_err!("第二个参数(编码)必须是字符串"),
        };

        // Read the whole file into memory.
        let mut file = File::open(file_path)
            .map_err(|e| DataFusionError::Execution(format!("打开CSV文件失败: {}", e)))?;
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes)
            .map_err(|e| DataFusionError::Execution(format!("读取CSV文件失败: {}", e)))?;

        // Decode to UTF-8 text. The remaining tuple elements (including the decoder's
        // error indication) are ignored, so decoding is best-effort.
        let (text, _, _) = encoding.decode(&bytes);

        // Parse the decoded text with the `csv` crate (not DataFusion's own CSV reader).
        let mut reader = csv::Reader::from_reader(text.as_bytes());

        // Build the schema from the header row: one nullable Utf8 field per header.
        let fields = reader
            .headers()
            .map_err(|e| DataFusionError::Execution(format!("读取CSV头失败: {}", e)))?
            .iter()
            .map(|name| Field::new(name, DataType::Utf8, true))
            .collect::<Vec<_>>();
        if fields.is_empty() {
            // Without any column a record batch cannot be built; report that directly.
            return Err(DataFusionError::Execution(
                "CSV文件为空或缺少表头".to_string(),
            ));
        }
        let column_count = fields.len();
        let schema = Arc::new(Schema::new(fields));

        // Collect the data column by column.
        let mut columns: Vec<Vec<String>> = vec![Vec::new(); column_count];
        for record in reader.records() {
            let record = record
                .map_err(|e| DataFusionError::Execution(format!("解析CSV记录失败: {}", e)))?;
            // Pair each value with its column; zip avoids indexing past the column list.
            for (column, value) in columns.iter_mut().zip(record.iter()) {
                column.push(value.to_owned());
            }
        }

        // Convert the collected strings into Arrow arrays and a single RecordBatch.
        let arrays = columns
            .into_iter()
            .map(|column| Arc::new(StringArray::from(column)) as ArrayRef)
            .collect::<Vec<_>>();
        let batch = RecordBatch::try_new(schema.clone(), arrays)
            .map_err(|e| DataFusionError::Execution(format!("创建RecordBatch失败: {}", e)))?;

        // Expose the batch through an in-memory table provider.
        let provider = MemTable::try_new(schema, vec![vec![batch]])
            .map_err(|e| DataFusionError::Execution(format!("创建内存表失败: {}", e)))?;
        Ok(Arc::new(provider))
    }
}
impl Default for ReadCsvFunction {
fn default() -> Self {
Self
}
}
read_xls.rs
use std::sync::Arc;
use std::path::Path;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use arrow::array::{ArrayRef, StringArray};
use datafusion::common::{DataFusionError, Result};
use datafusion::catalog::{TableFunctionImpl, TableProvider};
use datafusion::datasource::memory::MemTable;
use datafusion_expr::Expr;
use datafusion::common::ScalarValue;
use calamine::{open_workbook, Reader, Xlsx};
/// Table function that reads an Excel workbook.
#[derive(Debug)]
pub struct ReadXlsFunction;
impl TableFunctionImpl for ReadXlsFunction {
    /// Loads the first worksheet of an Excel workbook into an in-memory table.
    ///
    /// # Arguments
    /// * `exprs[0]` - string literal holding the path of the workbook.
    ///
    /// The first row only determines the number of columns: its cell text is not used
    /// for naming (columns are called `col_0`, `col_1`, ...) and the row itself is
    /// excluded from the data. Every column is exposed as a nullable `Utf8` column.
    ///
    /// # Errors
    /// Returns a planning error for a wrong argument count or a non-string argument
    /// (the same error kind `read_csv` uses for invalid arguments). Returns an execution
    /// error when the workbook cannot be opened, has no worksheet, the worksheet cannot
    /// be read or is empty, or the record batch / memory table cannot be built.
    fn call(&self, exprs: &[Expr]) -> Result<Arc<dyn TableProvider>> {
        // Exactly one argument is accepted.
        let [path_expr] = exprs else {
            return Err(DataFusionError::Plan(
                "read_xls函数需要1个参数: 文件路径".to_string(),
            ));
        };

        // The argument must be a string literal holding the file path.
        let Expr::Literal(ScalarValue::Utf8(Some(file_path)), _) = path_expr else {
            return Err(DataFusionError::Plan(
                "第一个参数必须是字符串类型的文件路径".to_string(),
            ));
        };

        // Open the workbook.
        let mut workbook: Xlsx<_> = open_workbook(Path::new(file_path)).map_err(|e| {
            DataFusionError::Execution(format!("打开Excel文件失败: {}", e))
        })?;

        // Take the first worksheet.
        let range = workbook
            .worksheet_range_at(0)
            .ok_or_else(|| DataFusionError::Execution("Excel文件中没有工作表".to_string()))?
            .map_err(|e| DataFusionError::Execution(format!("读取工作表失败: {}", e)))?;

        // The width of the first row fixes the number of columns.
        let column_count = range
            .rows()
            .next()
            .ok_or_else(|| DataFusionError::Execution("Excel工作表为空".to_string()))?
            .len();
        let fields = (0..column_count)
            .map(|i| Field::new(format!("col_{}", i), DataType::Utf8, true))
            .collect::<Vec<_>>();
        let schema = Arc::new(Schema::new(fields));

        // Collect the cell text column by column, skipping the first (title) row.
        let mut columns: Vec<Vec<String>> = vec![Vec::new(); column_count];
        for row in range.rows().skip(1) {
            // zip stops at the shorter side, so cells beyond the known columns are ignored.
            for (column, cell) in columns.iter_mut().zip(row.iter()) {
                column.push(cell.to_string());
            }
        }

        // Convert the collected strings into Arrow arrays.
        let arrow_columns = columns
            .into_iter()
            .map(|column| Arc::new(StringArray::from(column)) as ArrayRef)
            .collect::<Vec<_>>();

        // Assemble a single RecordBatch.
        let batch = RecordBatch::try_new(schema.clone(), arrow_columns).map_err(|e| {
            DataFusionError::Execution(format!("创建RecordBatch失败: {}", e))
        })?;

        // Expose the batch through an in-memory table provider.
        let provider = MemTable::try_new(schema, vec![vec![batch]]).map_err(|e| {
            DataFusionError::Execution(format!("创建内存表失败: {}", e))
        })?;
        Ok(Arc::new(provider))
    }
}
impl Default for ReadXlsFunction {
fn default() -> Self {
Self
}
}
main.rs
mod read_csv;
mod read_xls;
use std::sync::Arc;
use datafusion::execution::context::SessionContext;
use datafusion::error::Result;
use read_csv::ReadCsvFunction;
use read_xls::ReadXlsFunction;
/// 注册所有自定义函数到DataFusion上下文
async fn register_custom_functions(ctx: &SessionContext) -> Result<()> {
// 注册表函数
ctx.register_udtf("read_csv", Arc::new(ReadCsvFunction::default()));
// 注册函数
ctx.register_udtf("read_xls", Arc::new(ReadXlsFunction::default()));
Ok(())
}
/// Entry point: registers the custom table functions, then runs each demo query in
/// order and prints its result. The first failing step aborts the program.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Register the custom functions before any query refers to them.
    register_custom_functions(&ctx).await?;
    println!("所有自定义函数已注册成功!");

    // Demo queries: three CSV reads (default encoding, GBK, BIG5) and one Excel read.
    let queries = [
        "SELECT * FROM read_csv('/par/foods.csv') limit 3",
        "SELECT * FROM read_csv('/par/gbk_file.csv','GBK') limit 3",
        "SELECT * FROM read_csv('/par/big5_file.csv','BIG5') limit 3",
        "SELECT * FROM read_xls('/par/foods.xlsx')",
    ];
    for query in queries {
        let frame = ctx.sql(query).await?;
        frame.show().await?;
    }
    Ok(())
}
执行结果如下
target/debug/dfdf
所有自定义函数已注册成功!
+------------+----------+--------+----------+
| category | calories | fats_g | sugars_g |
+------------+----------+--------+----------+
| vegetables | 45 | 0.5 | 2 |
| seafood | 150 | 5 | 0 |
| meat | 100 | 5 | 0 |
+------------+----------+--------+----------+
+----+---------------------------+-------+
| id | lang | code |
+----+---------------------------+-------+
| 1 | 中文 | GBK |
| 2 | 〇镕镚閫閬 | GBK |
| 3 | 烎玊奣嘦勥巭嫑恏兲氼忈炛 | GBK |
+----+---------------------------+-------+
+----+--------+-------+
| id | lang | code |
+----+--------+-------+
| 1 | 中文 | GBK |
| 2 | 閫閬 | GBK |
| 3 | 烎玊 | GBK |
+----+--------+-------+
+------------+-------+-------+-------+
| col_0 | col_1 | col_2 | col_3 |
+------------+-------+-------+-------+
| vegetables | 45 | 0.5 | 2 |
| 海鲜 | 150 | 5 | 0 |
| 肉 | 100 | 5 | 0 |
| fruit | 60 | 0 | 11 |
+------------+-------+-------+-------+
总结:DeepSeek极大地降低了学习一门新程序语言的门槛,只要有任意一种语言基础,和阅读简单英文的能力,配合良好的文档,任何人都能做出有用的软件。