目录
构建并优化 Python 的 Rust 扩展
如果你的 Python 代码运行速度不够快,你可以选择使用编译语言来编写更快的扩展。本文将重点介绍 Rust,它具有以下优势:
- 现代工具链,包括名为 crates.io 的包仓库和内置的构建工具(
cargo
)。 - 出色的 Python 集成和工具支持。Rust 的 Python 支持包是 PyO3。对于打包,你可以使用 setuptools-rust 来与现有的
setuptools
项目集成,或者使用 Maturin 来构建独立的扩展。 - 内存安全和线程安全,相比 C 和 C++,Rust 更不容易崩溃或出现内存损坏问题。
具体来说,我们将:
- 在 Python 中实现一个小算法。
- 将其重新实现为 Rust 扩展。
- 优化 Rust 版本,使其运行更快。
近似计算唯一值的数量
作为示例,我们将计算列表中唯一值的数量。如果我们想要得到精确的结果,实现起来非常简单:
def count_exact(items):
return len(set(items))
set()
只能包含一个元素一次,因此它会去重并计算数量。
这个解决方案的问题是内存使用。如果我们有 10,000,000 个唯一值,我们将创建一个包含 10,000,000 个条目的集合,这将占用大量内存。
因此,如果我们担心内存使用,可以使用一种概率算法来给出一个近似答案。在许多情况下,这已经足够好,并且可以使用更少的内存。我们将使用 Chakraborty、Vinodchandran 和 Meel 提出的一种非常简单的算法:
import random
import math
# Implementation of https://arxiv.org/abs/2301.10191
def count_approx_python(items, epsilon=0.5, delta=0.001):
# Will be used to scale tracked_items upwards:
p = 1
# Items we're currently tracking, limited in length:
tracked_items = set()
# Max length of tracked_items:
max_tracked = round(
((12 / (epsilon ** 2)) *
math.log2(8 * len(items) / delta)
)
for item in items:
tracked_items.discard(item)
if random.random() < p:
tracked_items.add(item)
if len(tracked_items) == max_tracked:
# Drop tracked values with 50% probability.
# Every item in tracked_items now implies the
# final length is twice as large.
tracked_items = {
item for item in tracked_items
if random.random() < 0.5
}
p /= 2
if len(tracked_items) == 0:
raise RuntimeError(
"we got unlucky, no answer"
)
return int(round(len(tracked_items) / p))
运行示例
让我们看一个单词列表的示例:
WORDS = [str(i) for i in range(100_000)] * 100
random.shuffle(WORDS)
我们有 100,000 个不同的单词,这意味着 count_exact()
将创建一个大小为 100,000 的集合。而 count_approx()
的集合大小仅为 1739。它偶尔会有两个集合,但仍然只使用精确算法 3% 的内存。
我们可以比较结果:
print("EXACT", count_exact(WORDS))
print("APPROX", count_approx_python(WORDS))
我连续运行了三次,结果如下:
EXACT 100000
APPROX 99712
EXACT 100000
APPROX 99072
EXACT 100000
APPROX 100864
近似版本的结果有所波动,但非常接近——同时使用了更少的内存。
速度比较
我们可以比较两个实现的速度:
def timeit(f, *args, **kwargs):
start = time.time()
for _ in range(10):
f(*args, **kwargs)
print(f.__name__, (time.time() - start) / 10)
timeit(count_exact, WORDS, "seconds")
timeit(count_approx_python, WORDS, "seconds")
结果如下:
count_exact 0.14 seconds
count_approx_python 0.78 seconds
我们的新函数使用了更少的内存,但速度慢了 5 倍。 让我们看看是否可以通过将其重写为 Rust 来加速 count_approx_python()
。
加速:创建 Rust 项目
使用 Maturin Python 打包工具,我们可以快速创建一个新的 Rust Python 扩展,并使用 PyO3 轻松与 Python 对象交互。
1. 使用 Maturin 初始化项目
你可以使用 pipx
或 pip install maturin
安装 Maturin,然后使用 Maturin 初始化一个全新的基于 PyO3 的项目:
$ maturin new rust_count_approx
✔ 🤷 Which kind of bindings to use?
📖 Documentation: https://maturin.rs/bindings.html · pyo3
✨ Done! New project created rust_count_approx
$ cd rust_count_approx/
这将创建一个基本的 Rust Python 包所需的所有文件:
$ tree
├── Cargo.lock
├── Cargo.toml
├── pyproject.toml
└── src
└── lib.rs
我们可以立即安装这个包,无需进一步设置:
$ pip install .
...
Successfully installed rust_count_approx-0.1.0
2. 添加依赖
Rust 没有内置的随机数生成库,因此我们将使用 Rust 的构建/包管理器 cargo
添加一个第三方 crate(Rust 术语中的可安装包):
$ cargo add rand
Updating crates.io index
Adding rand v0.8.5 to dependencies
Features:
+ alloc
+ getrandom
+ libc
+ rand_chacha
+ std
+ std_rng
- log
- min_const_gen
- nightly
- packed_simd
- serde
- serde1
- simd_support
- small_rng
这将更新 Cargo.toml
文件,其中列出了构建代码所需的 Rust 依赖项。相关部分现在如下所示:
[dependencies]
pyo3 = "0.22.0"
rand = "0.8.5"
pyo3
依赖项是在我们初始化项目模板时由 Maturin 自动添加的。
命令输出中的功能列表是可以在编译时启用的标志,以添加更多功能;我们稍后会使用其中一个。
3. 第一个 Rust 版本
现在,我们需要更新 src/lib.rs
中的 Rust 代码来实现我们的函数:
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PySequence, PySet};
use rand::random;
#[pyfunction]
#[pyo3(signature = (items, epsilon=0.5, delta=0.001))]
fn count_approx_rust(
py: Python<'_>,
items: &Bound<PySequence>,
epsilon: f64,
delta: f64,
) -> PyResult<u64> {
let mut p = 1.0;
let mut tracked_items = PySet::empty_bound(py)?;
let max_tracked =
((12.0 / epsilon.powi(2)) *
(8.0 * (items.len()? as f64) / delta).log2()
).round() as usize;
for item in items.iter()? {
let item = item?;
tracked_items.discard(item.clone())?;
if random::<f64>() < p {
tracked_items.add(item)?;
}
if tracked_items.len() == max_tracked {
let mut temp_tracked_items =
PySet::empty_bound(py)?;
for subitem in tracked_items.iter() {
if random::<f64>() < 0.5 {
temp_tracked_items.add(subitem);
}
}
tracked_items = temp_tracked_items;
p /= 2.0;
if tracked_items.len() == 0 {
return Err(PyRuntimeError::new_err(
"we got unlucky, no answer"
));
}
}
}
Ok((tracked_items.len() as f64 / p).round() as u64)
}
#[pymodule]
fn rust_count_approx(
m: &Bound<'_, PyModule>
) -> PyResult<()> {
m.add_function(
wrap_pyfunction!(count_approx_rust, m)?
)?;
Ok(())
}
4. 测量性能
我们的新版本在速度上表现如何?
首先,我们安装新包:
$ pip install .
现在我们可以从 Python 代码中导入新函数并测量其速度:
from rust_count_approx import count_approx_rust
# See above for definition of WORDS and timeit():
timeit(count_approx_rust, WORDS)
以下是新版本的比较:
版本 | 耗时(秒) |
---|---|
Python | 0.78 |
Rust(初始) | 0.37 |
它的速度是原来的两倍。为什么没有更快?
这是因为逻辑上与 Python 实现相同的代码,只是用 Rust 实现并且更冗长。我们仍然以相同的方式与 Python 对象交互,遍历 Python 列表,并频繁与 Python 集合交互。因此,这部分代码的运行速度不会有太大差异。
进一步优化 Rust 代码
接下来,我们将通过四种不同的方式优化代码,所有这些方式都分别提高了性能。
优化 1:链接时优化
首先,我们将启用链接时优化(LTO),Rust 编译器在编译过程的后期进行优化。这意味着编译速度会变慢,但通常会生成运行更快的代码。我们通过在 Cargo.toml
中添加以下内容来实现:
[profile.release]
lto = true
优化 2 和 3:更快的随机数生成
我们还从使用 rand::random()
切换到我们自己管理的随机数生成器(RNG),移除了线程本地查找的开销。
同时,我们切换到使用比默认 RNG 更快的“小型”RNG;它不太安全,但对于我们的目的来说可能已经足够。为此,我们在 Cargo.toml
中添加 smallrng
功能:
[dependencies]
pyo3 = "0.22.0"
rand = {version = "0.8.5", features = ["small_rng"]}
我们还需要修改代码,如下所示。
优化 4:仅存储哈希值
最后,我们从在 Python 集合中存储 Python 对象切换到仅存储 Python 对象的哈希值,通过 obj.__hash__()
计算。如果两个对象的哈希值相同会发生什么?结果会相差 1。
我们已经在使用概率函数,因此我们已经可以接受稍微错误的结果。冲突应该很少见,如果我们有很多唯一值,偶尔相差 1 也不会有太大影响;99712 并不比 99713 更错误。
由于我们不再存储 Python 对象,我们可以切换到使用 Rust 的 std::collections::HashSet
,它有一些更好的 API 并且可能更快。然而,Rust 的 HashSet
会想要对值进行哈希……而我们不想再次哈希它们,它们已经被预哈希了。
为了避免双重哈希,我们还添加了 Rust crate nohash_hasher
:
$ cargo add nohash_hasher
这将允许我们使用一个 Rust 的 HashSet
,它假设存储的值已经是哈希值,可以直接使用而无需进一步哈希。Cargo.toml
的依赖部分现在如下所示:
[dependencies]
nohash-hasher = "0.2.0"
pyo3 = "0.22.0"
rand = {version = "0.8.5", features = ["small_rng"]}
我们还需要更新代码以使用它,如下所示。
优化后的代码
以下是更新后的代码:
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PySequence;
use rand::{rngs::SmallRng, SeedableRng, Rng};
use nohash_hasher::IntSet;
#[pyfunction]
#[pyo3(signature = (items, epsilon=0.5, delta=0.001))]
fn count_approx_rust(
py: Python<'_>,
items: &Bound<PySequence>,
epsilon: f64,
delta: f64,
) -> PyResult<u64> {
let mut p = 1.0;
let mut tracked_items = IntSet::default();
let mut rng = SmallRng::from_entropy();
let mut random = || rng.gen::<f64>();
let max_tracked =
((12.0 / epsilon.powi(2)) *
(8.0 * (items.len()? as f64) / delta).log2()
).round() as usize;
for item in items.iter()? {
let hash = item?.hash()?;
tracked_items.remove(&hash);
if random() < p {
tracked_items.insert(hash);
}
if tracked_items.len() == max_tracked {
tracked_items.retain(|_| random() < 0.5);
p /= 2.0;
if tracked_items.len() == 0 {
return Err(PyRuntimeError::new_err(
"we got unlucky, no answer"
));
}
}
Ok((tracked_items.len() as f64 / p).round() as u64)
}
再次,我们可以通过以下方式安装更新后的版本:
$ pip install .
以下是优化版本与之前版本的运行时间比较:
版本 | 耗时(秒) |
---|---|
Python | 0.78 |
Rust(初始) | 0.37 |
Rust(优化) | 0.21 |
为什么没有更快,以及更多想法
为什么 Rust 代码没有更快?我们最新的优化版本仍然与 Python 列表 items
交互,并使用 Python 的 __hash__
API 对每个对象进行哈希。这意味着我们仍然受限于 Python API 的速度。如果我们传入 Arrow 列或 NumPy 整数数组,我们可能会运行得更快。
更广泛地说,如果我们想要更快,还应该考虑使用不同的算法。有许多近似计数算法,我选择这个算法是因为它简单易懂,而不是因为它特别快。
大局观:为什么选择 Rust?
Rust 允许我们通过访问编译语言来加速代码。但 C、C++ 或 Cython 也是如此。
与这些语言不同,Rust 具有现代工具链,内置包管理器和构建系统:
- 添加新依赖项就像
cargo add <cratename>
一样简单。浏览可用 crate 的好网站是 lib.rs。 - 不太明显但仍然重要的是:这段代码将在 macOS 和 Windows 上编译,无需额外工作。相比之下,跨平台管理 C 或 C++ 依赖项可能非常痛苦。
Rust 还具有出色的 Python 集成,可以轻松访问 Python API 并使用像 Maturin 这样的打包工具。
Rust 既可以扩展到小型项目(如本文中的项目),也可以扩展到更大的项目,这要归功于其内存和线程安全性以及强大的类型系统。Polars 项目有一个 330K 行的通用 Rust 核心,然后用更多的 Rust 和 Python 包装成一个 Python 扩展。
下次你考虑加速 Python 代码时,试试 Rust。它不是一门简单的语言,但值得你花时间学习。