Python多进程间通讯(包含共享内存方式)


注:本博文测试环境为Linux系统。


1 通过非共享内存配合队列方式

下面是一个常见的生产者与消费者的模式示例,这里分别启动了两个子进程,一个为生产者(producer)一个为消费者(consumer)。生产者负责生产Numpy的NDArray数据(这里为了体现进程间传递数据会耗时故创建的NDArray的shape比较大),然后将数据放入队列Queue。消费者监控队列Queue一旦有数据就取出并简单打印下shape信息和填充的Value信息。

import time
import multiprocessing as mp
from multiprocessing import Process, Queue

import numpy as np


def producer_task(queue: Queue):
    for i in range(10):
        data = np.full(shape=(1, 3, 2048, 2048), fill_value=i, dtype=np.float32)
        queue.put(data)
        time.sleep(0.1)

    # send exit signal
    queue.put(None)
    print("producer exit.")


def consumer_task(queue: Queue):
    while True:
        data = queue.get()
        if data is None:
            break

        print(f"get data shape:{data.shape}, fill value:{data[0][0][0][0]}")

    print("consumer exit.")


def main():
    queue = Queue()
    producer = Process(target=producer_task, args=(queue,), name="producer")
    consumer = Process(target=consumer_task, args=(queue,), name="consumer")

    producer.start()
    consumer.start()

    producer.join()
    consumer.join()


if __name__ == '__main__':
    mp.set_start_method("spawn")
    main()

执行以上代码终端输出以下内容:

get data shape:(1, 3, 2048, 2048), fill value:0.0
get data shape:(1, 3, 2048, 2048), fill value:1.0
get data shape:(1, 3, 2048, 2048), fill value:2.0
get data shape:(1, 3, 2048, 2048), fill value:3.0
get data shape:(1, 3, 2048, 2048), fill value:4.0
get data shape:(1, 3, 2048, 2048), fill value:5.0
get data shape:(1, 3, 2048, 2048), fill value:6.0
get data shape:(1, 3, 2048, 2048), fill value:7.0
get data shape:(1, 3, 2048, 2048), fill value:8.0
get data shape:(1, 3, 2048, 2048), fill value:9.0
producer exit.
consumer exit.

为了进一步看清进程之间传递数据的过程,这里使用viztracer工具进一步分析(直接通过pip install viztraver即可安装)。使用指令如下,其中main.py就是上面的代码内容。跑完后会在当前目录下生成一个result.json文件。

viztracer main.py

通过如下指令可视化result.json文件:

vizviewer result.json

在终端输入上述指令后,终端会提示你打开网页并进入http://localhost:9001,如果使用的是VSCODE IDE在右下角也会提示你打开浏览器。
在这里插入图片描述

在这里插入图片描述
可以看到生产者进程在将数据放入队列后会先进行ForkingPickler.dump即数据序列化的过程,大概耗时12ms。然后开始posix.write即开始将数据从一个进程传递到另一个进程,大概耗时34ms。最后在消费者进程进行_pickle.loads即数据的反序列化,大概耗时6ms。从生产者进程将数据放入队列到消费者进程拿到数据总耗时约53ms。从这个示例中可以看到,当在进程间传递的数据量很大时会很耗时。


2 通过共享内存配合队列方式

下面示例代码将传递的数据改为了共享内存的方式,这样可以大幅减小进程间数据传递的成本。这里主要是使用multiprocessing库中的shared_memory.SharedMemory对象。创建新的共享内存时需要将create参数设置为True(如果是复用已有的共享内存时设置为False),然后指定具体的size大小,该参数为数据的字节大小,比如要申请一块存放数据类型为float32shape(1, 3, 2048, 2048)的空间所需字节数为1 * 3 * 2048 * 2048 * 4float32为4个字节)。根据Python官方文档介绍,当一个进程不在使用该共享内存时应关闭指向共享内存的文件描述符/句柄,具体做法是调用共享内存对象的close方法。当某块共享内存不在需要时,需在最后一个使用到的进程中调用unlink方法显示释放掉(如果不调用该方法,共享内存会一直存在,如果后续再不断申请新的共享内存则会出现共享内存泄露的问题,或者当程序未正常退出时该共享内存块会成为僵尸共享内存?)。例如在当前示例中,生产者进程创建了共享内存并放入队列里后可调用close方法关闭当前进程指向共享内存的文件描述符/句柄,在消费者进程中拿到数据并消费完后除了调用close方法外还会调用unlink方法删除该共享内存。有关共享内存的详细介绍看查看Python官方文档:
https://docs.python.org/zh-cn/3/library/multiprocessing.shared_memory.html#multiprocessing.shared_memory.SharedMemory

import time
import multiprocessing as mp
from multiprocessing import Process, Queue, shared_memory

import numpy as np


def producer_task(queue: Queue):
    for i in range(10):
        shm = shared_memory.SharedMemory(
            name=f"data_{i}",
            create=True,
            size=1 * 3 * 2048 * 2048 * 4
        )
        np_data = np.ndarray(shape=(1, 3, 2048, 2048), dtype=np.float32, buffer=shm.buf)
        np_data.fill(i)

        queue.put(shm.name)
        shm.close()
        time.sleep(0.1)

    # send exit signal
    queue.put(None)
    print("producer exit.")


def consumer_task(queue: Queue):
    while True:
        shm_name = queue.get()
        if shm_name is None:
            break

        shm = shared_memory.SharedMemory(name=shm_name, create=False)
        np_data = np.ndarray(shape=(1, 3, 2048, 2048), dtype=np.float32, buffer=shm.buf)
        print(f"get data shape:{np_data.shape}, fill value:{np_data[0][0][0][0]}")
        shm.close()
        shm.unlink()

    print("consumer exit.")


def main():
    queue = Queue()
    producer = Process(target=producer_task, args=(queue,), name="producer")
    consumer = Process(target=consumer_task, args=(queue,), name="consumer")

    producer.start()
    consumer.start()

    producer.join()
    consumer.join()


if __name__ == '__main__':
    mp.set_start_method("spawn")
    main()

同样我们使用viztracer来看看进程间的通讯情况:
在这里插入图片描述

数据从生产者进程传递到消费者进程耗时为245us相比之前不使用共享内存方法的53ms,速度比值为53000/245≈216X,提升还是非常明显的。但是这有个很奇怪的现象我无法理解,就是在生产者进程中调用close方法用了1.8ms,而在消费者进程里调用close方法只用了15us,unlink用了8us,如果有知道的大神希望能帮忙解释下。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

太阳花的小绿豆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值