目录
Python的多进程性能问题
由于Python在使用线程时存在并行性的限制,使用工作进程是利用多核CPU的常见方式。Python标准库中的multiprocessing
模块经常被用于此目的。
然而,尽管多进程允许你利用多个CPU,但在进程之间移动数据可能会非常慢。这可能会减少使用工作进程带来的性能优势。
让我们来看看:
- 为什么进程会有线程没有的性能问题。
- 一些解决或应对这种性能开销的方法。
- 一个你不应该使用的糟糕解决方案。
线程 vs. 进程
多线程允许你并行运行代码,可能是在多个CPU上。然而,在Python中,全局解释器锁(GIL)使得这种并行性难以实现。
多进程也允许你并行运行代码——那么线程和进程之间有什么区别呢?
同一个进程中的所有线程共享相同的内存地址空间。 如果进程中的线程1在地址0x7f0cd1a88810
存储了一些内存,线程2可以在相同的地址访问相同的内存。这意味着在线程之间传递对象是廉价的:你只需要将一个线程中的内存地址指针传递给另一个线程。一个内存地址只有8字节:这不是很多数据。
相比之下,进程不共享相同的内存空间。 操作系统通常提供一些共享内存设施,我们稍后会讨论。但默认情况下,没有内存是共享的。这意味着你不能简单地在进程之间共享数据的地址:你必须复制数据。
如果你在进程之间传递少量数据,这没问题;但如果你传递一个1GB的DataFrame
……那可能会变得非常昂贵。
Python中的多进程
到目前为止,我们讨论的是操作系统层面的进程,操作系统提供的设施基本上涉及字节的复制:从文件、共享内存,或者像mmap()
这样的混合方式。然而,当你编写Python代码时,你希望在进程之间共享Python对象。
为了实现这一点,当你使用Python的multiprocessing
库在进程之间传递Python对象时:
- 在发送端,参数会被
pickle
模块序列化为字节。 - 在接收端,字节会被
pickle
反序列化。
这种序列化和反序列化过程涉及计算,可能会很慢。让我们来看一个例子,比较线程池和进程池:
from time import time
import multiprocessing as mp
from multiprocessing.pool import ThreadPool
import numpy as np
import pickle
def main():
arr = np.ones((1024, 1024, 1024), dtype=np.uint8)
expected_sum = np.sum(arr)
with ThreadPool(1) as threadpool:
start = time()
assert (threadpool.apply(np.sum, (arr,)) == expected_sum)
print("Thread pool:", time() - start)
with mp.get_context("spawn").Pool(1) as processpool:
start = time()
assert (processpool.apply(np.sum, (arr,)) == expected_sum
print("Process pool:", time() - start)
if __name__ == "__main__":
main()
运行这个代码,我们得到以下结果:
$ python threads_vs_processes.py
Thread pool: 0.3097844123840332
Process pool: 1.8011224269866943
在子进程中运行代码比在线程中运行要慢得多,这并不是因为计算本身变慢了,而是因为复制和(反)序列化数据的开销。那么如何避免这种开销呢?
减少进程间复制数据的性能影响
选项 #1: 直接使用线程
进程有这种开销,而线程没有。虽然通用的Python代码在使用多线程时不会很好地并行化,但这并不一定适用于你的Python代码。例如,NumPy在许多操作中会释放GIL,这意味着即使使用线程,你也可以使用多个CPU核心。
例如:
import numpy as np
from time import time
from multiprocessing.pool import ThreadPool
arr = np.ones((1024, 1024, 1024))
start = time()
for i in range(10):
arr.sum()
print("Sequential:", time() - start)
expected = arr.sum()
start = time()
with ThreadPool(4) as pool:
result = pool.map(np.sum, [arr] * 10)
assert result == [expected] * 10
print("4 threads:", time() - start)
运行时,我们看到NumPy在使用线程时能够很好地利用多核,至少对于这个操作来说:
$ python numpy_gil.py
Sequential: 4.253053188323975
4 threads: 1.3854241371154785
在可以使用Python线程实现并行性的情况下,比如使用NumPy的API,使用进程的动机就小得多。
Pandas是基于NumPy构建的,因此许多数值操作可能会释放GIL。然而,任何涉及字符串或Python对象的操作都不会。因此,另一种方法是使用像Polars这样的库,它从一开始就设计为并行化,你甚至不需要考虑它,它有一个内部线程池。
选项 #2: 接受它
如果你不得不使用进程,你可能会决定接受pickle的开销。特别是,如果你尽量减少进程之间传递的数据量,并且每个进程中的计算足够显著,复制和序列化数据的成本可能不会显著影响程序的运行时间。如果后续计算需要10分钟,花费几秒钟在pickle上并不重要。
此外,值得注意的是,Python有一个更快的pickle版本,截至3.11版本尚未默认启用;它可能会在未来的版本中启用。这将减少pickle的开销,尽管它仍然存在。
选项 #3: 将数据写入磁盘
你可以将数据写入磁盘,然后将文件路径传递给子进程(作为参数)或父进程(作为工作进程中运行的函数的返回值)。接收进程可以解析该文件。
以下是一个比较直接传递DataFrame和使用临时Parquet文件传递DataFrame的示例:
import pandas as pd
import multiprocessing as mp
from pathlib import Path
from tempfile import mkdtemp
from time import time
def noop(df: pd.DataFrame):
# 实际代码会在这里处理DataFrame
pass
def noop_from_path(path: Path):
df = pd.read_parquet(path, engine="fastparquet")
# 实际代码会在这里处理DataFrame
pass
def main():
df = pd.DataFrame({"column": list(range(10_000_000))})
with mp.get_context("spawn").Pool(1) as pool:
# 直接通过pickle将DataFrame传递给工作进程
start = time()
pool.apply(noop, (df,))
print("Pickling-based:", time() - start)
# 将DataFrame写入文件,将文件路径传递给工作进程
start = time()
path = Path(mkdtemp()) / "temp.parquet"
df.to_parquet(
path,
engine="fastparquet",
# 跳过压缩以加快速度
compression="uncompressed",
)
pool.apply(noop_from_path, (path,))
print("Parquet-based:", time() - start)
if __name__ == "__main__":
main()
运行它,我们可以看到Parquet版本确实更快:
$ python tofile.py
Pickling-based: 0.24182868003845215
Parquet-based: 0.17243456840515137
Parquet在所有情况下可能不一定更快,pickle在未来的Python版本中可能会运行得更快,但在某些情况下,这种方法可能会有帮助。
选项 #4: multiprocessing.shared_memory
因为进程有时确实希望共享内存,操作系统通常提供显式创建进程间共享内存的设施。Python将这些设施封装在multiprocessing.shared_memory
模块中。
然而,与线程不同,线程共享相同的内存地址空间,可以轻松共享Python对象,而在这里你主要限于共享数组。正如我们所看到的,NumPy在昂贵的操作中会释放GIL,这意味着你可以直接使用线程,这要简单得多。不过,如果你需要,知道这个模块的存在是值得的。
注意: 该模块还包括
ShareableList
,它有点像Python列表,但仅限于int
、float
、bool
、小str
和bytes
,以及None
。但这并不能帮助你廉价地共享任意Python对象。
Linux上的糟糕选项:"fork"
上下文
你可能已经注意到我们使用multiprocessing.get_context("spawn").Pool()
来创建进程池。这是因为Python在某些操作系统上有多种多进程实现。"spawn"
是Windows上的唯一选项,macOS上的唯一非破坏性选项,并且在Linux上也可用。使用"spawn"
时,会创建一个全新的进程,因此你总是需要跨进程复制数据。
在Linux上,默认是"fork"
:新的子进程在创建时拥有父进程内存的完整副本。这意味着在子进程创建之前,父进程中创建的任何对象(数组、巨大的字典等)如果存储在某个有用的地方(如模块中),子进程可以访问它们。这意味着你不需要pickle/unpickle来访问它们。
听起来很有用,对吧?只有一个问题:"fork"
上下文非常容易出问题,这就是为什么它将在Python 3.14中不再是默认选项。
考虑以下程序:
import threading
import sys
from multiprocessing import Process
def thread1():
for i in range(1000):
print("hello", file=sys.stderr)
threading.Thread(target=thread1).start()
def foo():
pass
Process(target=foo).start()
在我的电脑上,这个程序总是死锁:它冻结并且永远不会退出。任何时候你在父进程中有线程,"fork"
上下文都可能导致子进程中的潜在死锁,甚至内存损坏。
你可能会认为你没问题,因为你没有启动任何线程。但许多Python库在导入时会启动一个线程池,例如NumPy。如果你使用NumPy、Pandas或任何其他依赖NumPy的库,你正在运行一个多线程程序,因此在使用"fork"
多进程上下文时,存在死锁、段错误或数据损坏的风险。
所以理论上这是Linux上的一个选项,但实际上你真的不想使用它。因此,我不会向你展示如何以这种方式跨进程传递数据。如果你真的想知道,其他地方有文章演示了这一点,但如果你采用这种方法,你只是在自找麻烦。