C++线程池实战(2)--任务窃取

前沿:第一篇文章主要介绍了C++线程池的最基础的写法,虽然基础但十分重要,后续所有相关的项目思路都是基于该线程池的实现方式;为了方便演示,上文将线程池的实现放在了构造函数中,在本文以及以后的文章中,会将该部分代码作为类的私有工具函数,不再作为可以被外部捕捉到的接口:该部分是线程池的核心工作逻辑,属于类的内部实现细节,不应该暴露给外部,并且使用者无需关心线程如何循环、调度,降低了使用的复杂度,也符合封装原则。

本节内容是实现一个支持任务窃取的线程池

2.1是什么是任务窃取?

任务窃取是并行计算或线程池中的一种负载均衡技术,用于解决多线程任务分配不均的问题。其核心逻辑是:当一个工作线程(Worker Thread)的本地任务队列为空时,它会主动去其他工作线程的任务队列中 “窃取” 任务来执行,避免自身空闲。这能让系统资源(线程)得到更充分的利用,提高整体执行效率。

任务窃取通常用于支持嵌套任务或动态任务分配的场景(例如并行递归、分治算法等),常见于高性能线程池实现中(如 C++ 的std::async底层可能用到类似思想,或一些第三方线程池库)。

在传统线程池中,任务通常被 “静态分配”(比如按顺序分给每个线程,或随机分配),但这种方式存在明显缺陷:

  • 不同任务的执行时间差异很大(比如有的任务耗时 1ms,有的耗时 1000ms);
  • 动态任务生成场景(比如任务执行中会创建新的子任务,如并行递归、分治算法)。

这两种情况都会导致负载不均:有的线程早早执行完所有任务,进入空闲状态;而有的线程还堆积了大量任务,持续忙碌。此时 CPU 资源被浪费(空闲线程的核心利用率为 0),整体执行效率大幅下降。

总结就是:任务窃取就是为解决这个问题而生 —— 让空闲线程 “主动找活干”,平衡各线程的负载。

2.2任务窃取的特点

  1. 无中心化竞争:每个线程优先处理自己的本地队列,只有空闲时才去 “窃取”,避免了所有线程争抢同一个 “全局任务队列” 的高频锁竞争,性能更高。
  2. 动态负载均衡:不需要提前预测任务耗时,也不需要中央调度器实时监控负载 —— 空闲线程主动找活干,动态平衡各线程的任务量,最大化 CPU 利用率。
  3. 适配动态任务场景:特别适合 “任务会生成子任务” 的场景(如并行归并排序、并行斐波那契计算、异步 IO 回调链)。子任务直接放入当前线程的本地队列,后续可被其他空闲线程窃取,天然适配动态任务流。

2.3代码展示如下所示

#include <cstddef>
#include <iostream>
#include <atomic>
#include <functional>
#include <future>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <mutex>
#include <thread>
#include <chrono>
#include <condition_variable>
#include <utility>
#include <vector>
#include <queue>
#include <random>
using namespace std;

//支持窃取任务的线程池
class ThreadPool
{  
public:
    explicit ThreadPool(size_t num_threads) : stop_flag(false)
    {
        if(num_threads == 0) num_threads = 1;

        workers_steal.resize(num_threads);//初始化窃取任务队列大小
        for(int i = 0; i < num_threads; i++)
        {
            workers.emplace_back([this, i](){worker_loop(i);});
        }
    }

    template<typename F, typename... Args>
    auto submit(F&&f, Args&&... args) -> std::future<typename std::invoke_result<F, Args...>::type>
    {
        using ReturnType = typename std::invoke_result<F, Args...>::type;
        auto task = std::make_shared<std::packaged_task<ReturnType()>>(std::bind(std::forward<F>(f), std::forward<Args>(args)...));
        std::future<ReturnType> res = task->get_future();
        {
            std::lock_guard<std::mutex> lock(queue_mutex);
            if(this->stop_flag.load()) throw runtime_error("submit on a stopped threadpool");
            tasks.emplace([task]{(*task)();});
        }
        cv_queue.notify_one();
        return res;
    }

    template<typename F, typename... Args>//默认参数不能位于可变参数之后
    auto submit_with_steal(int preferred_worker, F&&f, Args&&... args) -> std::future<typename std::invoke_result<F, Args...>::type>
    {
        using ReturnType = typename std::invoke_result<F, Args...>::type;
        auto task = std::make_shared<std::packaged_task<ReturnType()>>(std::bind(std::forward<F>(f), std::forward<Args>(args)...));
        std::future<ReturnType> res = task->get_future();

        static thread_local  random_device rd;
        static thread_local mt19937 gen(rd());
        uniform_int_distribution<size_t> dist(0, workers_steal.size()-1);
        int target_worker = preferred_worker;
        if(preferred_worker >= workers_steal.size()) target_worker = dist(gen);//如果指定的偏好线程大于窃取任务容器的容量,则将target_worker置为任意数
        {
            std::lock_guard<std::mutex> lock(steal_mutex);
            if(this->stop_flag.load()) throw runtime_error("submit on a stopped threadpool");
            workers_steal[target_worker].emplace([task]{(*task)();});//将任务存储到workers_steal的指定任务
        }
        cv_steal.notify_one();
        return res;
    }

    template<typename F, typename... Args>//默认参数不能位于可变参数之后
    auto submit_with_steal(F&&f, Args&&... args) -> std::future<typename std::invoke_result<F, Args...>::type>
    {
        return submit_with_steal(0, std::forward<F>(f), std::forward<Args>(args)...);
    }
    ~ThreadPool()
    {
        this->stop_flag.store(true);
        cv_queue.notify_all();
        cv_steal.notify_all();
        for(auto&f : workers)
        {
            if(f.joinable()) f.join();
        }
    }
private:
    //工作线程所需数据结构
    std::atomic<bool> stop_flag;
    std::mutex queue_mutex;
    std::condition_variable cv_queue;
    std::vector<std::thread> workers;
    std::queue<std::function<void()>> tasks;
    //窃取任务所需数据结构
    std::mutex steal_mutex;
    std::vector<std::queue<std::function<void()>>> workers_steal;//存储每个线程的私有队列
    std::condition_variable cv_steal;
    //工作线程主循环
    void worker_loop(int worker_id)
    {
        while(true)
        {
            std::function<void()> task;
            bool has_task = false;//*
            //阶段一:优先执行普通任务
            {
                std::unique_lock<std::mutex> lock(queue_mutex);
                //这里一定要是wait_for,
                //如果使用wait,调用窃取任务时tasks为空,wait就会阻塞到这里无法执行
                //使用wait_for的话,就算tasks为空,但是等到10ms之后就会自动向下执行代码
                cv_queue.wait_for(lock, chrono::milliseconds(10), [this](){return this->stop_flag.load() || !tasks.empty();});
                //if(this->stop_flag.load() && tasks.empty()) throw runtime_error("submit on a stopped threadpool");
                if(!tasks.empty())
                {
                    task = std::move(tasks.front());
                    tasks.pop();
                    has_task = true;
                }
                
            }
            //阶段二:执行自己的窃取任务
            if(!has_task && !this->stop_flag.load())
            {
                std::unique_lock<std::mutex> lock(steal_mutex);
                cv_steal.wait_for(lock, chrono::milliseconds(10), [this, worker_id]{return stop_flag.load() || !workers_steal[worker_id].empty();});
                if( !workers_steal[worker_id].empty())
                {
                    task = std::move(workers_steal[worker_id].front());
                    workers_steal[worker_id].pop();
                    has_task = true;
                }
            }
            //阶段三:若自己对应的窃取任务为空,就窃取别人的任务
            if(!has_task && !this->stop_flag.load())
            {
                std::unique_lock<std::mutex> lock(steal_mutex);
                for(int j = 0; j < workers_steal.size(); j++)
                {
                    if(j == worker_id) continue;//跳过自己
                    if(workers_steal[j].empty()) continue; // 跳过空队列
                    task = std::move(workers_steal[j].front());
                    workers_steal[j].pop();
                    has_task = true;
                    break;
                }
            }
            //执行任务
            if(has_task && task)
            {
                task();
            }
            //检查所有任务,实现退出while死循环
            if(stop_flag.load())
            {
                bool all_empty = true;
                //检查工作线程
                { 
                    std::lock_guard<std::mutex> lock(queue_mutex);
                    all_empty = all_empty && tasks.empty();
                }
                //检查窃取队列
                {
                    std::lock_guard<std::mutex> lock(steal_mutex);
                    for(auto& f : workers_steal)
                    {
                        if(!f.empty())
                        {
                            all_empty = false;
                            break;
                        }
                    }
                }
                if(all_empty) break;
            }
        }
    }
};

2.3.1测试普通任务

//测试普通任务
void test01()
{
    std::mutex mtx;//确保顺序输出
    ThreadPool pool(std::thread::hardware_concurrency() ? std::thread::hardware_concurrency() : 4);
    vector<std::future<int>> res;
    for(int i = 0; i < 10; i++)
    {
        res.emplace_back(pool.submit([i, &mtx]()
        {
            std::lock_guard<std::mutex> lock(mtx);
            cout << "task #:" << i << " is running in thread " << std::this_thread::get_id() << endl;
            std::this_thread::sleep_for(std::chrono::milliseconds(10ms));
            return i * i;
        }));
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(3s));
    for(int i = 0; i < res.size(); i++)
    {
        cout << "res: " << res[i].get() << endl;
    }
    
}   

输出结果如下所示:

2.3.2测试窃取任务

//测试窃取任务
void test02()
{
    std::mutex mtx1;//确保顺序输出
    ThreadPool pool(std::thread::hardware_concurrency() ? std::thread::hardware_concurrency() : 4);
    vector<std::future<int>> res;
    for(int i = 0; i < 10; i++)
    {
        res.emplace_back(pool.submit_with_steal(0,[i, &mtx1]
            {
                std::lock_guard<std::mutex> lock(mtx1);
                cout << "steal task#:" << i << " is running in thread " << std::this_thread::get_id() << endl;
                std::this_thread::sleep_for(std::chrono::milliseconds(50));
                return i * i * i;
            }));
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(3s));
    for(auto&f : res)
    {
        cout << "steal res: " << f.get() << endl;
    }
}

输出结果如下所示:

2.4总结--可以先看这里

与第一节对比,可发现窃取任务多了几个数据结构,来支撑完成窃取任务。同时,在线程池类中多了一个submit_with_steal的对外接口,执行普通任务和窃取任务是两个不同的接口,我们可以在程序中定义执行的优先级,即,worker_loop函数中,将普通任务放到前面就先执行普通任务,反过来道理是一样的。

存储普通任务的数据结构与存储窃取任务的数据结构是两个不同的结构,二者不会存在关联关系(如果都放在一起,那就不是两个任务了)。读到这里会发现,代码中提供了两个接口,不像是窃取任务(二者之间可以说是没有什么关联性),更像是多任务队列分配,窃取任务本质应该是在同一个队列中执行任务,首先恭喜大家发现了问题的本质,在worker_loop中阶段二和阶段三,简单了在一个队列中实现了真正的任务窃取。我自己没有任务可以执行了,就跑去看看谁的任务在还没有被执行,就帮忙给它的执行掉。

下一节会实现一个关联性强的真正的窃取任务,大家可以思考一下使用哪种数据结构可以实现我们的目标。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值