目录
当 NumPy 太慢时
如果你在进行数值计算,NumPy 比普通的 Python 快得多——但有时这还不够。当你的基于 NumPy 的代码太慢时,你应该怎么做?
你的第一反应可能是并行化,但这可能是你最后才应该考虑的事情。在并行化变得有用之前,你可以通过算法改进或绕过 NumPy 的架构限制来进行许多优化。
让我们看看为什么 NumPy 可能会变慢,然后讨论一些可以帮助你进一步加速代码的解决方案。
第一步:在优化之前,选择一个可扩展的算法
在花太多时间考虑如何加速你的 NumPy 代码之前,确保你选择了一个可扩展的算法。
一个 O(N)
的算法比 O(N²)
的算法扩展性更好;随着 N
的增长,后者很快就会变得不可用,即使使用快速的实现。例如,我见过一些实际案例,其中在排序的 Python 列表上进行二分查找比在 C 语言中进行线性查找要快得多。
因此,一旦你处理大量数据——几乎可以肯定 NumPy 变慢的情况——选择一个好的算法对速度至关重要。你不希望浪费时间优化一个以后需要替换的算法。
NumPy 的固有性能限制
一旦你有了一个好的算法,你就可以开始考虑优化。专注于单核执行,没有并行化的情况下,NumPy 代码可能不够快的三个特定原因如下:
瓶颈 1:即时执行
考虑以下代码:
def add_multiple(arr):
result = arr + 17
result += 17
result += 17
result += 17
result += 17
result += 17
result += 17
result += 17
result += 17
return result
显然,反复添加这些数字是愚蠢的。一个更快的实现如下:
def add_multiple_efficient(arr):
return arr + (17 * 9)
然而,NumPy 无法 自动将第一个函数转换为第二个函数,因为 NumPy 对整体执行一无所知。NumPy 会逐条执行你给出的每条语句,无法预知未来的语句。
换句话说,单个操作可能很快,但没有机制可以自动优化一系列操作……即使优化非常明显。
瓶颈 2:通用的编译代码
不同的 CPU 有不同的能力,例如不同的 SIMD(单指令多数据)操作集,这些操作可以显著加速数值代码。然而,你从 PyPI 或 Conda 下载的 NumPy 包——或者更广泛地说,任何预编译的包——必须能在你安装的任何随机计算机上运行。
NumPy 确实提供了越来越多操作的手动 SIMD 实现,但这必须逐个操作添加。编译器有时可以自动生成 SIMD 指令,但它将受限于最低的共同标准;现代 CPU 指令不会被使用。
更广泛地说,不同的 CPU 对不同操作的成本不同。由于你下载的 NumPy 编译版本需要在不同的 CPU 上运行,它将使用一个广泛适用的成本模型进行编译,但永远不会完全匹配你 CPU 的具体特性。
瓶颈 3:向量化带来的更高内存使用
考虑以下函数:
import numpy as np
def mean_distance_from_zero(arr):
return np.abs(arr).mean()
这将创建一个绝对值的临时数组,如果我们能够有效地使用 for
循环,这个数组是不必要的。虽然这增加了内存使用,但它也对 CPU 有影响:CPU 缓存无法像不同值需要被推入和逐出时那样高效地使用。
第二步:优化你的 NumPy 代码
如何解决这三个瓶颈?你有多种选择:
选项 1:手动优化 NumPy 代码
你可以使用多种技术。
重写你的 NumPy 代码以提高效率
即使有一个可扩展的算法,也可能有方法通过消除冗余工作或重构代码来提高效率。
例如,注意上面两个版本的 add_multiple()
都是 O(N)
:它们随着输入大小的增加而线性扩展。输入数组长度加倍,运行时间也会加倍。但即使它们的扩展性相似,一个实现比另一个快得多,因为它对每个条目做的工作更少。
使用现有的原生代码函数
在许多情况下,你需要的操作已经在 NumPy 或其他库(如 SciPy、Scikit-Image 等)中以高效的原生代码实现。这些实现通常比你自己的实现运行得更快。
例如,我们可以自己实现一个方差计算:
def myvar(arr):
mean = arr.mean()
diff = arr - arr.mean()
diff **= 2
return diff.sum() / arr.size
或者我们可以直接使用 numpy.var()
。事实证明,numpy.var()
在我的电脑上快了约 1.25 倍。
值得熟悉 NumPy API、SciPy API 和其他相关库,以便你知道哪些操作已经可用。
重新编译 NumPy 及其相关库以针对你的特定 CPU
可以重新编译 NumPy 和其他库,使它们专门针对你的 CPU 进行编译。在某些情况下,这会加速你计算机上的操作,因为编译器将能够使用你特定 CPU 上的所有可用指令。实际上,这可能非常痛苦,对大多数人来说,大多数时候不值得这样做。
选项 2:使用 JAX 进行自动加速的即时编译
你可以使用 JAX 库来自动使用即时编译来优化使用 NumPy 的函数。简而言之,jax.jit()
将获取一个使用 NumPy API 的函数,并在运行时将其编译为原生代码。它还有一个低级 API,可以给你更多控制。
- 与 NumPy 的一次一个操作的即时执行不同,JAX 的 JIT 分析整个函数,因此它可以跨操作优化代码。
- 它即时生成代码,因此理论上可以针对你的本地 CPU 进行定制。
- 它可能会决定使用多个 CPU。
例如:
from jax import jit
@jit
def add_multiple(arr):
result = arr + 17
result += 17
result += 17
result += 17
result += 17
result += 17
result += 17
result += 17
result += 17
return result
这个函数比原始的 NumPy 等效函数快 4 倍,并且由于使用了多个 CPU,返回结果的速度快了 12 倍。
当然,这个例子很傻,因为没有理由编写像上面那样的代码。但 JAX 也可以优化那些不太明显如何手动提高效率的 NumPy 操作。
选项 3:使用 Numba 即时编译重写代码
有些操作很难表达为全数组操作,或者至少内存效率低下,你真的希望使用 for
循环来实现它们。JAX 通过其低级 API 提供了一些支持,但另一个即时编译解决方案是 Numba,它给你更多控制权,同时仍然让你受益于针对本地 CPU 的编译器。特别是,你可以直接使用 Python 的一个子集来编写带有 for
循环的代码,然后将其编译为原生代码。
例如,这个方差实现比 NumPy 的快 3 倍:
from numba import njit
@njit
def myvar2(arr):
mean = arr.mean()
sum = np.float64(0.0)
for item in arr:
sum += (np.float64(item) - mean) ** 2
return sum / arr.size
选项 4:提前编译
Numba 允许你编写带有 for
循环的代码,并在运行时编译。另一种选择是使用 Cython、Rust、Fortran 或其他编译语言 来编写算法的编译版本。与 Numba 一样,使用编译器可以让你获得自动优化,并且使用 for
循环可以提高内存效率。
与 Numba 不同,代码需要提前编译。因此,默认情况下,编译的扩展将针对通用 CPU,而不是你的特定 CPU。你可以让编译器针对你的 CPU 进行编译,但与具有不同 CPU 的其他计算机共享编译后的代码可能会很困难。
比较选项
每种选项如何解决我们考虑的三种瓶颈?
技术提供: | 跨操作优化 | CPU 特定代码 | 内存效率 |
---|---|---|---|
手动优化 | 有时 | 有时 | 有时 |
JAX | 是(在函数内) | 是 | 不清楚 |
Numba | 是(在函数内) | 是 | 是 |
Cython/Rust/Fortran/C | 是(在函数内) | 需要额外工作 | 是 |
第三步:考虑并行化
正如我在其他地方详细讨论的那样,你应该在考虑并行化之前专注于优化单核性能。更快的算法可以带来巨大的加速,除此之外,使用像 Numba 这样的工具可以使你的代码更高效,甚至可能获得额外的 3 倍、10 倍,甚至25 倍的加速(即快 2400%!)。
一旦你达到了单核加速的极限,你可以考虑并行化。这里提到的许多工具都有能力使用多核:如果你选择使用 Numba,Rust 如果你使用 Rayon,JAX 的 JIT 可能会自动使用多核。但如果你直接跳到并行化,你可能会忽略一个数量级(或更多!)的潜在加速。
总结
当 NumPy 代码变慢时,首先确保你选择了合适的算法,然后考虑手动优化、使用 JAX 或 Numba 进行即时编译,或者使用 Cython、Rust 等语言进行提前编译。最后,在单核优化达到极限时,再考虑并行化。