在这篇文章中,我们将为 Postgres 实现 vector 类型:
CREATE TABLE items (v vector(3));
Postgres 扩展结构和 pgrx 包装器
在实现它之前,让我们先看看典型的扩展结构,以及 pgrx 如何为我们简化它。
典型的 Postgres 扩展可以大致分为 2 层:
- 实现,通常使用 C 等低级语言完成。
- 将实现粘合到 Postgres 的高级 SQL 语句。
- 指定扩展的一些基本属性的控制文件。
如果你看一下 pgvector 的源代码,这个 3 层结构非常明显,src 目录
用于 C 代码,sql 目录包含更高级的 SQL 胶水,还有一个
.control 文件。那么 pgrx 如何使扩展构建更容易?
- 它用 Rust 包装 Postgres C API
正如我们所说,即使我们用 Rust 构建扩展,Postgres 的 API 仍然是 C,pgrx 尝试将它们包装在 Rust 中,这样我们就不需要为 C 烦恼了。
- 如果可能,使用 Rust 宏生成 SQL 胶水
稍后我们将看到 pgrx 可以自动为我们生成 SQL 胶水。
- pgrx 为我们生成
.control文件
CREATE TYPE vector
我们来定义我们的 Vector 类型,使用 std::vec::Vec 看起来非常简单,而且由于 vector 需要存储浮点数,我们在这里使用 f64:
struct Vector {
value: Vec<f64>
}
然后呢?
用于创建新类型的 SQL 语句是 CREATE TYPE ...,从它的 文档,我们会知道我们正在实现的 vector 类型是一个 基类型,要创建基类型,需要支持函数 input_function 和 output_function。而且由于它需要采用使用 modifer 实现的维度参数(vector(DIMENSION)),因此还需要函数 type_modifier_input_function 和 type_modifier_output_function。因此,我们需要为我们的 Vector 类型实现这 4 个函数。
input_function
引用文档,
input_function将类型的外部文本表示转换为为该类型定义的运算符和函数使用的内部表示。输入函数可以声明为采用一个
cstring类型的参数,也可以声明为采用三个cstring、oid、integer类型的参数。第一个参数是作为 C 字符串的输入文本,第二个参数是类型自己的 OID(数组类型除外,它们接收其元素类型的 OID),第三个是目标列的 typmod(如果已知)(如果未知,则传递 -1)。输入函数必须返回数据类型本身的值。
好的,从文档来看,这个 input_function 用于反序列化,serde 是 Rust 中最流行的反序列化库,所以让我们使用它。对于参数,由于 vector 需要类型修饰符,我们需要它接受 3 个参数。我们的 input_function 如下所示:
#[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
fn vector_input(
input: &CStr,
_oid: pg_sys::Oid,
type_modifier: i32,
) -> Vector {
let value = match serde_json::from_str::<Vec<f64>>(
input.to_str().expect("expect input to be UTF-8 encoded"),
) {
Ok(v) => v,
Err(e) => {
pgrx::error!("failed to deserialize the input string due to error {}", e)
}
};
let dimension = match u16::try_from(value.len()) {
Ok(d) => d,
Err(_) => {
pgrx::error!("this vector's dimension [{}] is too large", value.len());
}
};
// cast should be safe as dimension should be a positive
let expected_dimension = match u16::try_from(type_modifier) {
Ok(d) => d,
Err(_) => {
panic!("failed to cast stored dimension [{}] to u16", type_modifier);
}
};
// check the dimension
if dimension != expected_dimension {
pgrx::error!(
"mismatched dimension, expected {}, found {}",
expected_dimension,
dimension
);
}
Vector {
value }
}
这有一大堆东西,让我们逐一研究一下。
#[pg_extern(immutable, strict, parallel_safe, require = [ "shell_type" ])]
如果你用 pg_extern 标记一个函数,那么 pgrx 会自动为你生成类似 CREATE FUNCTION <你的函数> 的 SQL,immutable, strict, parallel_safe 是你认为你的函数具有的属性,它们与 CREATE FUNCTION 文档 中列出的属性相对应。因为这个 Rust 宏用于生成 SQL,并且 SQL 可以相互依赖,所以这个 requires = [ "shell_type" ] 用于明确这种依赖关系。
shell_type 是另一个定义 shell 类型的 SQL 代码段的名称,什么是 shell 类型?它的行为就像一个占位符,这样我们在完全实现它之前就可以有一个 vector 类型来使用。此 #[pg_extern] 宏生成的 SQL 将是:
CREATE FUNCTION "vector_input"(
"input" cstring,
"_oid" oid,
"type_modifier" INT
) RETURNS vector
如您所见,此函数 RETURNS vector,但在实现这 4 个必需函数之前,我们如何才能拥有 vector 类型?

Shell 类型正是为此而生!我们可以定义一个 shell 类型(虚拟类型,不需要提供任何函数),并让我们的函数依赖于它:

pgrx 不会为我们定义这个 shell 类型,我们需要在 SQL 中手动执行此操作,如下所示:
extension_sql!(
r#"CREATE TYPE vector; -- shell type"#,
name = "shell_type"
);
extension_sql!() 宏允许我们在 Rust 代码中编写 SQL,然后 pgrx 会将其包含在生成的 SQL 脚本中。name = "shell_type" 指定此 SQL 代码段的标识符,可用于引用它。我们的 vector_input() 函数依赖于此 shell 类型,因此它 requires = [ "shell_type" ]。
fn vector_input(
input: &CStr,
_oid: pg_sys::Oid,
type_modifier: i32,
) -> Vector {
input 参数是一个表示我们的向量输入文本的字符串,_oid 以 _ 为前缀,因为我们不需要它。type_modifier 参数的类型为 i32,这就是类型修饰符在 Postgres 中的存储方式。当我们实现类型修饰符输入/输出函数时,我们将再次看到它。
let value = match serde_json::from_str::<Vec<f64>>(
input.to_str().expect("expect input to be UTF-8coded"),
) {
Ok(v) => v,
Err(e) => {
pgrx::error!("failed to deserialize the input string due to error {}", e)
}
};
然后我们将 input 转换为 UTF-8 编码的 &str 并将其传递给 serde_json::from_str()。输入文本应该是 UTF-8 编码的,所以我们应该是安全的。如果在反序列化过程中发生任何错误,只需使用 pgrx::error!() 输出错误,它将在 error 级别记录并终止当前事务。
let dimension = match u16::try_from(value.len()) {
Ok(d) => d,
Err(_) => {
pgrx::error!("此向量的维度 [{}] 太大", value.len());
}
};
// cast should be safe as dimension should be a positive
let expected_dimension = match u16::try_from(type_modifier) {
Ok(d) => d,
Err(_) => {
panic!("无法将存储的维度 [{}] 转换为 u16", type_modifier);
}
};
我们支持的最大维度是 u16::MAX,我们这样做只是因为这是 pgvector 所做的。
// check the dimension
if dimension != expected_dimension {
pgrx::error!(
"mismatched dimension, expected {}, found {}",
expected_dimension,
dimension
);
}
Vector {
value }
最后,我们检查输入向量是否具有预期的维度,如果没有,则出错。否则,我们返回解析后的向量。
output_function
output_function 执行反向操作,它将给定的向量序列化为字符串。这是我们的实现:
#[pg_extern(immutable, strict, parallel_safe, require = [ "shell_type" ])]
fn vector_output(value: Vector) -> CString {
let value_serialized_string = serde_json::to_string(&value).unwrap();
CString::new(value_serialized_string).expect("中间不应该有 NUL")
}
我们只需序列化 Vec<f64> 并将其返回到 CString 中,非常简单。
type_modifier_input_function
type_modifier_input_function 应该解析输入修饰符,检查解析的修饰符,如果有效,则将其编码为整数,这是 Postgres 存储类型修饰符的方式。
一个类型可以接受多个类型修饰符,用 , 分隔,这就是我们在这里看到数组的原因。
#[pg_extern(immutable, strict, parallel_safe, requires = [ "shell_type" ])]
fn vector_modifier_input(list: pgrx::datum::Array<&CStr>) -> i32 {
if list.len() != 1 {
pgrx::error!("too many modifiers, expect 1")
}
let modifier = list
.get(0)
.expect

最低0.47元/天 解锁文章
397

被折叠的 条评论
为什么被折叠?



