应用-构建并优化 Python 的 Rust 扩展

目录

构建并优化 Python 的 Rust 扩展

如果你的 Python 代码运行速度不够快,你可以选择使用编译语言来编写更快的扩展。本文将重点介绍 Rust,它具有以下优势:

  • 现代工具链,包括名为 crates.io 的包仓库和内置的构建工具(cargo)。
  • 出色的 Python 集成和工具支持。Rust 的 Python 支持包是 PyO3。对于打包,你可以使用 setuptools-rust 来与现有的 setuptools 项目集成,或者使用 Maturin 来构建独立的扩展。
  • 内存安全和线程安全,相比 C 和 C++,Rust 更不容易崩溃或出现内存损坏问题。

具体来说,我们将:

  1. 在 Python 中实现一个小算法。
  2. 将其重新实现为 Rust 扩展。
  3. 优化 Rust 版本,使其运行更快。

近似计算唯一值的数量

作为示例,我们将计算列表中唯一值的数量。如果我们想要得到精确的结果,实现起来非常简单:

def count_exact(items):
    return len(set(items))

set() 只能包含一个元素一次,因此它会去重并计算数量。

这个解决方案的问题是内存使用。如果我们有 10,000,000 个唯一值,我们将创建一个包含 10,000,000 个条目的集合,这将占用大量内存。

因此,如果我们担心内存使用,可以使用一种概率算法来给出一个近似答案。在许多情况下,这已经足够好,并且可以使用更少的内存。我们将使用 Chakraborty、Vinodchandran 和 Meel 提出的一种非常简单的算法

import random
import math

# Implementation of https://arxiv.org/abs/2301.10191
def count_approx_python(items, epsilon=0.5, delta=0.001):
    # Will be used to scale tracked_items upwards:
    p = 1
    # Items we're currently tracking, limited in length:
    tracked_items = set()
    # Max length of tracked_items:
    max_tracked = round(
        ((12 / (epsilon ** 2)) *
        math.log2(8 * len(items) / delta)
    )
    for item in items:
        tracked_items.discard(item)
        if random.random() < p:
            tracked_items.add(item)
        if len(tracked_items) == max_tracked:
            # Drop tracked values with 50% probability.
            # Every item in tracked_items now implies the
            # final length is twice as large.
            tracked_items = {
                item for item in tracked_items
                if random.random() < 0.5
            }
            p /= 2
            if len(tracked_items) == 0:
                raise RuntimeError(
                    "we got unlucky, no answer"
                )
    return int(round(len(tracked_items) / p))

运行示例

让我们看一个单词列表的示例:

WORDS = [str(i) for i in range(100_000)] * 100
random.shuffle(WORDS)

我们有 100,000 个不同的单词,这意味着 count_exact() 将创建一个大小为 100,000 的集合。而 count_approx() 的集合大小仅为 1739。它偶尔会有两个集合,但仍然只使用精确算法 3% 的内存。

我们可以比较结果:

print("EXACT", count_exact(WORDS))
print("APPROX", count_approx_python(WORDS))

我连续运行了三次,结果如下:

EXACT 100000
APPROX 99712

EXACT 100000
APPROX 99072

EXACT 100000
APPROX 100864

近似版本的结果有所波动,但非常接近——同时使用了更少的内存。

速度比较

我们可以比较两个实现的速度:

def timeit(f, *args, **kwargs):
    start = time.time()
    for _ in range(10):
        f(*args, **kwargs)
    print(f.__name__, (time.time() - start) / 10)

timeit(count_exact, WORDS, "seconds")
timeit(count_approx_python, WORDS, "seconds")

结果如下:

count_exact 0.14 seconds
count_approx_python 0.78 seconds

我们的新函数使用了更少的内存,但速度慢了 5 倍。 让我们看看是否可以通过将其重写为 Rust 来加速 count_approx_python()

加速:创建 Rust 项目

使用 Maturin Python 打包工具,我们可以快速创建一个新的 Rust Python 扩展,并使用 PyO3 轻松与 Python 对象交互。

1. 使用 Maturin 初始化项目

你可以使用 pipxpip install maturin 安装 Maturin,然后使用 Maturin 初始化一个全新的基于 PyO3 的项目:

$ maturin new rust_count_approx
✔ 🤷 Which kind of bindings to use?
  📖 Documentation: https://maturin.rs/bindings.html · pyo3
  ✨ Done! New project created rust_count_approx
$ cd rust_count_approx/

这将创建一个基本的 Rust Python 包所需的所有文件:

$ tree
├── Cargo.lock
├── Cargo.toml
├── pyproject.toml
└── src
    └── lib.rs

我们可以立即安装这个包,无需进一步设置:

$ pip install .
...
Successfully installed rust_count_approx-0.1.0

2. 添加依赖

Rust 没有内置的随机数生成库,因此我们将使用 Rust 的构建/包管理器 cargo 添加一个第三方 crate(Rust 术语中的可安装包):

$ cargo add rand
    Updating crates.io index
      Adding rand v0.8.5 to dependencies
             Features:
             + alloc
             + getrandom
             + libc
             + rand_chacha
             + std
             + std_rng
             - log
             - min_const_gen
             - nightly
             - packed_simd
             - serde
             - serde1
             - simd_support
             - small_rng

这将更新 Cargo.toml 文件,其中列出了构建代码所需的 Rust 依赖项。相关部分现在如下所示:

[dependencies]
pyo3 = "0.22.0"
rand = "0.8.5"

pyo3 依赖项是在我们初始化项目模板时由 Maturin 自动添加的。

命令输出中的功能列表是可以在编译时启用的标志,以添加更多功能;我们稍后会使用其中一个。

3. 第一个 Rust 版本

现在,我们需要更新 src/lib.rs 中的 Rust 代码来实现我们的函数:

use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::{PySequence, PySet};
use rand::random;

#[pyfunction]
#[pyo3(signature = (items, epsilon=0.5, delta=0.001))]
fn count_approx_rust(
    py: Python<'_>,
    items: &Bound<PySequence>,
    epsilon: f64,
    delta: f64,
) -> PyResult<u64> {
    let mut p = 1.0;
    let mut tracked_items = PySet::empty_bound(py)?;
    let max_tracked =
        ((12.0 / epsilon.powi(2)) *
         (8.0 * (items.len()? as f64) / delta).log2()
        ).round() as usize;
    for item in items.iter()? {
        let item = item?;
        tracked_items.discard(item.clone())?;
        if random::<f64>() < p {
            tracked_items.add(item)?;
        }
        if tracked_items.len() == max_tracked {
            let mut temp_tracked_items =
                PySet::empty_bound(py)?;
            for subitem in tracked_items.iter() {
                if random::<f64>() < 0.5 {
                    temp_tracked_items.add(subitem);
                }
            }
            tracked_items = temp_tracked_items;
            p /= 2.0;
            if tracked_items.len() == 0 {
                return Err(PyRuntimeError::new_err(
                    "we got unlucky, no answer"
                ));
            }
        }
    }
    Ok((tracked_items.len() as f64 / p).round() as u64)
}

#[pymodule]
fn rust_count_approx(
    m: &Bound<'_, PyModule>
) -> PyResult<()> {
    m.add_function(
        wrap_pyfunction!(count_approx_rust, m)?
    )?;
    Ok(())
}

4. 测量性能

我们的新版本在速度上表现如何?

首先,我们安装新包:

$ pip install .

现在我们可以从 Python 代码中导入新函数并测量其速度:

from rust_count_approx import count_approx_rust

# See above for definition of WORDS and timeit():
timeit(count_approx_rust, WORDS)

以下是新版本的比较:

版本耗时(秒)
Python0.78
Rust(初始)0.37

它的速度是原来的两倍。为什么没有更快?

这是因为逻辑上与 Python 实现相同的代码,只是用 Rust 实现并且更冗长。我们仍然以相同的方式与 Python 对象交互,遍历 Python 列表,并频繁与 Python 集合交互。因此,这部分代码的运行速度不会有太大差异。

进一步优化 Rust 代码

接下来,我们将通过四种不同的方式优化代码,所有这些方式都分别提高了性能。

优化 1:链接时优化

首先,我们将启用链接时优化(LTO),Rust 编译器在编译过程的后期进行优化。这意味着编译速度会变慢,但通常会生成运行更快的代码。我们通过在 Cargo.toml 中添加以下内容来实现:

[profile.release]
lto = true

优化 2 和 3:更快的随机数生成

我们还从使用 rand::random() 切换到我们自己管理的随机数生成器(RNG),移除了线程本地查找的开销。

同时,我们切换到使用比默认 RNG 更快的“小型”RNG;它不太安全,但对于我们的目的来说可能已经足够。为此,我们在 Cargo.toml 中添加 smallrng 功能:

[dependencies]
pyo3 = "0.22.0"
rand = {version = "0.8.5", features = ["small_rng"]}

我们还需要修改代码,如下所示。

优化 4:仅存储哈希值

最后,我们从在 Python 集合中存储 Python 对象切换到仅存储 Python 对象的哈希值,通过 obj.__hash__() 计算。如果两个对象的哈希值相同会发生什么?结果会相差 1。

我们已经在使用概率函数,因此我们已经可以接受稍微错误的结果。冲突应该很少见,如果我们有很多唯一值,偶尔相差 1 也不会有太大影响;99712 并不比 99713 更错误。

由于我们不再存储 Python 对象,我们可以切换到使用 Rust 的 std::collections::HashSet,它有一些更好的 API 并且可能更快。然而,Rust 的 HashSet 会想要对值进行哈希……而我们不想再次哈希它们,它们已经被预哈希了。

为了避免双重哈希,我们还添加了 Rust crate nohash_hasher

$ cargo add nohash_hasher

这将允许我们使用一个 Rust 的 HashSet,它假设存储的值已经是哈希值,可以直接使用而无需进一步哈希。Cargo.toml 的依赖部分现在如下所示:

[dependencies]
nohash-hasher = "0.2.0"
pyo3 = "0.22.0"
rand = {version = "0.8.5", features = ["small_rng"]}

我们还需要更新代码以使用它,如下所示。

优化后的代码

以下是更新后的代码:

use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PySequence;
use rand::{rngs::SmallRng, SeedableRng, Rng};
use nohash_hasher::IntSet;

#[pyfunction]
#[pyo3(signature = (items, epsilon=0.5, delta=0.001))]
fn count_approx_rust(
    py: Python<'_>,
    items: &Bound<PySequence>,
    epsilon: f64,
    delta: f64,
) -> PyResult<u64> {
    let mut p = 1.0;
    let mut tracked_items = IntSet::default();
    let mut rng = SmallRng::from_entropy();
    let mut random = || rng.gen::<f64>();
    let max_tracked =
        ((12.0 / epsilon.powi(2)) *
         (8.0 * (items.len()? as f64) / delta).log2()
        ).round() as usize;
    for item in items.iter()? {
        let hash = item?.hash()?;
        tracked_items.remove(&hash);
        if random() < p {
            tracked_items.insert(hash);
        }
        if tracked_items.len() == max_tracked {
            tracked_items.retain(|_| random() < 0.5);
            p /= 2.0;
            if tracked_items.len() == 0 {
                return Err(PyRuntimeError::new_err(
                    "we got unlucky, no answer"
                ));
        }
    }
    Ok((tracked_items.len() as f64 / p).round() as u64)
}

再次,我们可以通过以下方式安装更新后的版本:

$ pip install .

以下是优化版本与之前版本的运行时间比较:

版本耗时(秒)
Python0.78
Rust(初始)0.37
Rust(优化)0.21

为什么没有更快,以及更多想法

为什么 Rust 代码没有更快?我们最新的优化版本仍然与 Python 列表 items 交互,并使用 Python 的 __hash__ API 对每个对象进行哈希。这意味着我们仍然受限于 Python API 的速度。如果我们传入 Arrow 列或 NumPy 整数数组,我们可能会运行得更快。

更广泛地说,如果我们想要更快,还应该考虑使用不同的算法。有许多近似计数算法,我选择这个算法是因为它简单易懂,而不是因为它特别快。

大局观:为什么选择 Rust?

Rust 允许我们通过访问编译语言来加速代码。但 C、C++ 或 Cython 也是如此。

与这些语言不同,Rust 具有现代工具链,内置包管理器和构建系统:

  • 添加新依赖项就像 cargo add <cratename> 一样简单。浏览可用 crate 的好网站是 lib.rs
  • 不太明显但仍然重要的是:这段代码将在 macOS 和 Windows 上编译,无需额外工作。相比之下,跨平台管理 C 或 C++ 依赖项可能非常痛苦。

Rust 还具有出色的 Python 集成,可以轻松访问 Python API 并使用像 Maturin 这样的打包工具。

Rust 既可以扩展到小型项目(如本文中的项目),也可以扩展到更大的项目,这要归功于其内存和线程安全性以及强大的类型系统。Polars 项目有一个 330K 行的通用 Rust 核心,然后用更多的 Rust 和 Python 包装成一个 Python 扩展。

下次你考虑加速 Python 代码时,试试 Rust。它不是一门简单的语言,但值得你花时间学习。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

李星星BruceL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值