Python Multithreading and the GIL: A Deep Dive

Table of Contents

How the GIL Works and Its Impact

What Is the GIL

When the GIL Is Released

The GIL in I/O-Bound Tasks

Thread Safety and Locking

Demonstrating Thread-Safety Problems

Locking Primitives in Detail

Mutex Lock

Condition Variable

Semaphore

Barrier

Deadlock Detection and Avoidance

Lock-Ordering Strategy

Lock-Timeout Strategy

Thread Pools and Task Scheduling

A Custom Thread Pool Implementation

Priority-Based Task Scheduling

Hands-On Project: A High-Concurrency Task Processing System


How the GIL Works and Its Impact

What Is the GIL

The GIL (Global Interpreter Lock) is a mutex inside the CPython interpreter that ensures only one thread executes Python bytecode at any given moment.

import threading
import time

def demonstrate_gil_impact():
    """Demonstrate the GIL's impact on CPU-bound tasks"""
    def cpu_intensive_task(name, iterations):
        """CPU-bound task"""
        start_time = time.time()
        total = 0
        for i in range(iterations):
            total += i ** 2
        end_time = time.time()
        print(f"{name} finished in {end_time - start_time:.2f}s, result: {total}")
        return total
    
    iterations = 10000000
    
    # Single-threaded run
    print("=== Single-threaded run ===")
    start_time = time.time()
    cpu_intensive_task("single-thread", iterations)
    cpu_intensive_task("single-thread", iterations)
    single_thread_time = time.time() - start_time
    print(f"Single-threaded total: {single_thread_time:.2f}s\n")
    
    # Multi-threaded run
    print("=== Multi-threaded run ===")
    start_time = time.time()
    
    threads = []
    for i in range(2):
        thread = threading.Thread(
            target=cpu_intensive_task, 
            args=(f"thread-{i+1}", iterations)
        )
        threads.append(thread)
        thread.start()
    
    for thread in threads:
        thread.join()
    
    multi_thread_time = time.time() - start_time
    print(f"Multi-threaded total: {multi_thread_time:.2f}s")
    print(f"Speedup: {single_thread_time / multi_thread_time:.2f}x")

demonstrate_gil_impact()

CPython, the standard implementation of Python, includes a mechanism called the Global Interpreter Lock (GIL). The GIL guarantees that only one thread executes Python bytecode at a time; this matters little for I/O-bound tasks, but it becomes a performance bottleneck for CPU-bound ones. In the single-threaded run, the two tasks execute sequentially in T1 seconds. In the multi-threaded run, two threads execute the same two tasks concurrently in T2 seconds. Ideally, if the threads could fully use multiple CPU cores, T2 would approach T1 / 2, a speedup of nearly 2x. Because of the GIL, however, the two threads cannot truly execute Python bytecode in parallel, so the multi-threaded run is not noticeably faster and can even be slower than the single-threaded one.
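
For CPU-bound work, the usual way around the GIL is to use processes instead of threads, since each process has its own interpreter and its own GIL. A minimal sketch using the standard library's ProcessPoolExecutor (the task and iteration count mirror the demo above; this is an illustration, not part of the original benchmark):

import time
from concurrent.futures import ProcessPoolExecutor

def cpu_intensive_task(iterations):
    """CPU-bound task: sum of squares."""
    return sum(i ** 2 for i in range(iterations))

if __name__ == "__main__":
    iterations = 10_000_000
    start = time.time()
    # Each task runs in its own process with its own GIL,
    # so the two tasks can use two cores in parallel.
    with ProcessPoolExecutor(max_workers=2) as executor:
        results = list(executor.map(cpu_intensive_task, [iterations, iterations]))
    print(f"Two processes took {time.time() - start:.2f}s")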

When the GIL Is Released

import threading
import time

class GILMonitor:
    """Monitor for GIL-driven thread switches"""
    def __init__(self):
        self.switch_count = 0
        self.lock = threading.Lock()
    
    def monitor_thread_switches(self):
        """Monitor thread switches"""
        def worker(worker_id):
            local_switches = 0
            for i in range(1000000):
                # Do some computation
                _ = i ** 2
                
                # Check for a thread switch every 10000 operations
                if i % 10000 == 0:
                    current_thread = threading.current_thread().ident
                    with self.lock:
                        if hasattr(self, 'last_thread') and self.last_thread != current_thread:
                            local_switches += 1
                        self.last_thread = current_thread
            
            with self.lock:
                self.switch_count += local_switches
            print(f"Worker {worker_id} done, local switches: {local_switches}")
        
        # Create several worker threads
        threads = []
        for i in range(4):
            thread = threading.Thread(target=worker, args=(i,))
            threads.append(thread)
            thread.start()
        
        for thread in threads:
            thread.join()
        print(f"Total thread switches: {self.switch_count}")

# Run the monitor
monitor = GILMonitor()
monitor.monitor_thread_switches()

The GIL causes frequent thread switches, which affects program performance. The example runs a large amount of computation and counts thread switches, making the switching behavior caused by the GIL directly visible.

Code structure: the GILMonitor class contains switch_count, a global switch counter; lock, a thread lock protecting access to shared state; and monitor_thread_switches, a method that starts several threads running compute tasks while monitoring switches. The worker function is the task each thread executes:

It performs a large number of squaring operations to simulate CPU-bound work and checks the current thread ID every 10,000 iterations. By comparing the current thread ID with the last recorded one, it decides whether a switch occurred, tallies the local switch count, and finally adds it to the global counter. Four threads are started and joined, and the total switch count is printed.
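
How often CPython asks the running thread to release the GIL is governed by the interpreter's switch interval, which defaults to 5 ms. A minimal sketch of inspecting and tuning it with the standard sys APIs (the actual effect on switch counts is workload-dependent):

import sys

# Default is 0.005 s: a thread holding the GIL is asked to
# release it roughly every 5 ms of pure-Python execution.
print(f"Current switch interval: {sys.getswitchinterval()}s")

# A larger interval means fewer forced switches (less context-switch
# overhead for CPU-bound threads); a smaller one makes threads take
# turns more often, at the cost of more switching.
sys.setswitchinterval(0.01)
print(f"New switch interval: {sys.getswitchinterval()}s")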

The GIL in I/O-Bound Tasks

import time
import requests
from concurrent.futures import ThreadPoolExecutor

def io_intensive_comparison():
    """Compare an I/O-bound workload under different concurrency models"""
    
    urls = [
        'http://baidu.com',
        'https://baidu.com',
        'https://baidu.com',
        'https://baidu.com',
        'https://baidu.com'
    ]
    
    def fetch_url(url):
        """Fetch a URL synchronously"""
        try:
            response = requests.get(url, timeout=10)
            return f"Status: {response.status_code}, URL: {url}"
        except Exception as e:
            return f"Error: {str(e)}, URL: {url}"
    
    # Serial run
    print("=== Serial run ===")
    start_time = time.time()
    for url in urls:
        result = fetch_url(url)
        print(result)
    serial_time = time.time() - start_time
    print(f"Serial time: {serial_time:.2f}s\n")
    
    # Multi-threaded run
    print("=== Multi-threaded run ===")
    start_time = time.time()
    
    with ThreadPoolExecutor(max_workers=5) as executor:
        futures = [executor.submit(fetch_url, url) for url in urls]
        for future in futures:
            print(future.result())
    
    thread_time = time.time() - start_time
    print(f"Threaded time: {thread_time:.2f}s")
    print(f"Speedup: {serial_time / thread_time:.2f}x\n")

io_intensive_comparison()

Serial run: each URL is requested in turn, waiting for the response before moving on; since every request must wait for the network, the waiting time cannot be used for anything else. Multi-threaded run: a ThreadPoolExecutor issues the requests concurrently, so the waits overlap and total time drops significantly, improving overall throughput. Conclusion: Python threads are a good fit for I/O-bound tasks, where they turn waiting time into useful concurrency.
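
For comparison, the same I/O-bound workload is also a natural fit for asyncio. A minimal sketch assuming the third-party aiohttp package is installed (the URL list mirrors the demo above):

import asyncio
import time
import aiohttp  # assumption: third-party package, pip install aiohttp

async def fetch(session, url):
    """Fetch one URL and report its status code."""
    async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as resp:
        return f"Status: {resp.status}, URL: {url}"

async def main():
    urls = ['https://baidu.com'] * 5
    start = time.time()
    async with aiohttp.ClientSession() as session:
        # All five requests are in flight at once on a single thread;
        # the event loop switches tasks while each request waits.
        results = await asyncio.gather(*(fetch(session, u) for u in urls))
    for r in results:
        print(r)
    print(f"asyncio time: {time.time() - start:.2f}s")

asyncio.run(main())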

Thread Safety and Locking

Demonstrating Thread-Safety Problems

import threading
import time

class UnsafeCounter:
    """A counter that is not thread-safe"""
    def __init__(self):
        self.count = 0

    def increment(self):
        """Non-atomic increment"""
        temp = self.count
        # Simulate processing time to raise the odds of a race
        time.sleep(0.0001)
        self.count = temp + 1
    
    def get_count(self):
        return self.count

class SafeCounter:
    """A thread-safe counter"""
    def __init__(self):
        self.count = 0
        self.lock = threading.Lock()
    
    def increment(self):
        """Increment protected by a lock"""
        with self.lock:
            temp = self.count
            time.sleep(0.0001)
            self.count = temp + 1
    
    def get_count(self):
        with self.lock:
            return self.count

def test_thread_safety():
    """Test thread safety"""
    def worker(counter, iterations):
        """Worker function"""
        for _ in range(iterations):
            counter.increment()
    
    iterations = 100
    num_threads = 10
    expected_count = iterations * num_threads
    
    # Test the unsafe version
    print("=== Unsafe counter ===")
    unsafe_counter = UnsafeCounter()
    
    threads = []
    for i in range(num_threads):
        thread = threading.Thread(target=worker, args=(unsafe_counter, iterations))
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()
    unsafe_result = unsafe_counter.get_count()
    print(f"Expected: {expected_count}")
    print(f"Actual: {unsafe_result}")
    print(f"Lost updates: {expected_count - unsafe_result}\n")
    
    # Test the thread-safe version
    print("=== Safe counter ===")
    safe_counter = SafeCounter()
    
    threads = []
    for i in range(num_threads):
        thread = threading.Thread(target=worker, args=(safe_counter, iterations))
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()
    
    safe_result = safe_counter.get_count()
    print(f"Expected: {expected_count}")
    print(f"Actual: {safe_result}")
    print(f"Lost updates: {expected_count - safe_result}")

test_thread_safety()

In a multithreaded program, when several threads read and modify shared data at the same time without proper synchronization, data races and inconsistent state can result; this is called a race condition. Race conditions produce wrong results and undermine program correctness.

UnsafeCounter (not thread-safe)
Its increment method has no synchronization: it reads the current value, delays briefly, then writes back the new value. When several threads run concurrently, they may read the same stale value, so some increments are lost.

SafeCounter (thread-safe)
It protects the increment with a threading.Lock mutex, guaranteeing that only one thread modifies the counter at a time and avoiding the race condition.

Test flow: several threads each increment the counter a fixed number of times, and the final counts of the unsafe and safe versions are compared. The expected result is the thread count times the per-thread increments.

Result analysis: the unsafe counter usually ends up below the expected value, showing lost updates where the race condition caused some increments to be overwritten.

The safe counter matches the expected value exactly, showing that the lock eliminates the race condition and preserves data consistency.
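
The root cause is that even a plain count += 1 compiles to several bytecode instructions, and the GIL can switch threads between any two of them. A minimal sketch using the standard dis module to make that visible:

import dis

def increment(counter):
    # Looks atomic in source, but is load / add / store in bytecode.
    counter.count += 1

# The disassembly shows separate load, in-place add, and store
# instructions; a thread switch between the load and the store is
# exactly the lost-update race demonstrated above.
dis.dis(increment)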

Locking Primitives in Detail

import threading
import time
import random

class LockMechanisms:
    """Demonstrations of the main locking primitives"""
    def __init__(self):
        self.mutex_lock = threading.Lock()
        self.rw_lock = threading.RLock()  # reentrant lock
        self.condition = threading.Condition()
        self.semaphore = threading.Semaphore(3)  # semaphore: up to 3 threads at once
        self.event = threading.Event()
        self.barrier = threading.Barrier(3)  # barrier: waits for 3 threads
        self.shared_data = []
        self.readers_count = 0
    
    def demonstrate_mutex(self):
        print("=== Mutex demo ===")
        def critical_section(thread_id):
            with self.mutex_lock:
                print(f"Thread {thread_id} entered the critical section")
                time.sleep(1)
                self.shared_data.append(thread_id)
                print(f"Thread {thread_id} left the critical section")
        
        threads = []
        for i in range(3):
            thread = threading.Thread(target=critical_section, args=(i,))
            threads.append(thread)
            thread.start()
        for thread in threads:
            thread.join()
        print(f"Shared data: {self.shared_data}\n")
    
    def demonstrate_condition(self):
        print("=== Condition variable demo ===")
        items = []
        def consumer(consumer_id):
            with self.condition:
                while len(items) == 0:
                    print(f"Consumer {consumer_id} waiting for items")
                    self.condition.wait()
                item = items.pop(0)
                print(f"Consumer {consumer_id} consumed {item}")
        def producer():
            for i in range(5):
                time.sleep(0.5)
                with self.condition:
                    item = f"item-{i}"
                    items.append(item)
                    print(f"Producer produced {item}")
                    self.condition.notify()
        # Start the consumers
        consumers = []
        for i in range(2):
            consumer_thread = threading.Thread(target=consumer, args=(i,))
            consumers.append(consumer_thread)
            consumer_thread.start()
        # Start the producer
        producer_thread = threading.Thread(target=producer)
        producer_thread.start()
        producer_thread.join()
        for consumer in consumers:
            consumer.join()
        print()
    
    def demonstrate_semaphore(self):
        print("=== Semaphore demo ===")
        def access_resource(thread_id):
            with self.semaphore:
                print(f"Thread {thread_id} acquired the resource")
                time.sleep(2)
                print(f"Thread {thread_id} released the resource")
        
        threads = []
        for i in range(6):  # 6 threads compete for 3 slots
            thread = threading.Thread(target=access_resource, args=(i,))
            threads.append(thread)
            thread.start()
        for thread in threads:
            thread.join()
        print()
    
    def demonstrate_barrier(self):
        print("=== Barrier demo ===")
        def worker(worker_id):
            print(f"Worker {worker_id} started")
            time.sleep(random.uniform(1, 3))
            print(f"Worker {worker_id} finished, waiting for the others")
            
            self.barrier.wait()  # wait until all threads reach the barrier
            print(f"Worker {worker_id} continuing with follow-up work")
        
        threads = []
        for i in range(3):
            thread = threading.Thread(target=worker, args=(i,))
            threads.append(thread)
            thread.start()
        for thread in threads:
            thread.join()

# Run the lock demos
lock_demo = LockMechanisms()
lock_demo.demonstrate_mutex()
lock_demo.demonstrate_condition()
lock_demo.demonstrate_semaphore()
lock_demo.demonstrate_barrier()

Mutex Lock

Purpose: guarantees that only one thread is inside the critical section at a time, preventing data corruption from concurrent modification of shared resources. In the example, three threads enter the critical section in turn, simulating access to and modification of shared data while keeping each operation atomic. Key point: with self.mutex_lock: acquires and releases the lock automatically, which is concise and safe.

Condition Variable

Purpose: a wait/notify mechanism between threads, well suited to the producer-consumer model. In the example, consumer threads wait while the item list is empty (condition.wait()), and the producer notifies a consumer after producing an item (condition.notify()). Key point: a condition variable must be used together with its lock so that waiting and notification stay properly synchronized. In practice this pattern is often written with queue.Queue instead, as sketched below.
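
A minimal sketch of the same producer-consumer exchange using the standard queue.Queue, which handles the locking and waiting internally:

import queue
import threading

items = queue.Queue()

def consumer(consumer_id):
    # get() blocks until an item is available: no explicit
    # condition variable or lock is needed.
    item = items.get()
    print(f"Consumer {consumer_id} consumed {item}")

def producer():
    for i in range(2):
        items.put(f"item-{i}")

consumers = [threading.Thread(target=consumer, args=(i,)) for i in range(2)]
for c in consumers:
    c.start()
threading.Thread(target=producer).start()
for c in consumers:
    c.join()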

Semaphore

Purpose: limits how many threads may access a resource at the same time, suited to guarding a finite pool of resources. In the example, 6 threads compete for 3 slots, so at most 3 threads are inside the critical section at once. Key point: with self.semaphore: acquires and releases the semaphore automatically, bounding the concurrency.

Barrier

Purpose: makes a group of threads wait at a point until all of them have arrived, then lets them continue together, which suits phase-based synchronization. In the example, 3 workers each finish their own task and wait; once all are done, they continue together. Key point: self.barrier.wait() blocks a thread until every participating thread has called it.

Summary: a mutex protects shared resources against data races; a condition variable implements wait/notify between threads for producer-consumer patterns; a semaphore bounds the number of threads using a finite resource; a barrier synchronizes threads at phase boundaries, keeping multi-threaded cooperation orderly. (The demo class also constructs an RLock; see the sketch below.)
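
The LockMechanisms class above creates a threading.RLock (reentrant lock) but never exercises it. A minimal sketch of what distinguishes it from a plain Lock:

import threading

rlock = threading.RLock()

def outer():
    with rlock:        # first acquisition
        inner()        # re-acquiring in the same thread is fine

def inner():
    # With a plain Lock this second acquire would deadlock;
    # an RLock tracks its owner and a recursion count instead.
    with rlock:
        print("inner reached while outer still holds the lock")

outer()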

Deadlock Detection and Avoidance

import threading
import time
from contextlib import contextmanager

class DeadlockDemo:
    """Deadlock demonstration and avoidance"""
    def __init__(self):
        self.lock1 = threading.Lock()
        self.lock2 = threading.Lock()
        self.ordered_locks = [self.lock1, self.lock2]  # locks in a fixed global order
    
    def demonstrate_deadlock(self):
        print("=== Deadlock demo ===")
        def task1():
            print("Task 1: trying to acquire lock 1")
            with self.lock1:
                print("Task 1: acquired lock 1")
                time.sleep(1)
                print("Task 1: trying to acquire lock 2")
                with self.lock2:
                    print("Task 1: acquired lock 2, doing work")
        
        def task2():
            print("Task 2: trying to acquire lock 2")
            with self.lock2:
                print("Task 2: acquired lock 2")
                time.sleep(1)
                print("Task 2: trying to acquire lock 1")
                with self.lock1:
                    print("Task 2: acquired lock 1, doing work")
        
        thread1 = threading.Thread(target=task1)
        thread2 = threading.Thread(target=task2)
        thread1.start()
        thread2.start()
        
        # Wait briefly; if a deadlock occurred, the threads never finish
        thread1.join(timeout=5)
        thread2.join(timeout=5)
        
        if thread1.is_alive() or thread2.is_alive():
            print("Deadlock detected!")
        else:
            print("Tasks finished normally")
    
    def avoid_deadlock_with_ordering(self):
        print("\n=== Avoiding deadlock via lock ordering ===")
        
        def task_with_ordering(task_name, first_lock_idx, second_lock_idx):
            # Always acquire locks in the same global order
            locks_to_acquire = sorted([first_lock_idx, second_lock_idx])
            
            print(f"{task_name}: acquiring locks in order")
            with self.ordered_locks[locks_to_acquire[0]]:
                print(f"{task_name}: acquired lock {locks_to_acquire[0]}")
                time.sleep(1)
                with self.ordered_locks[locks_to_acquire[1]]:
                    print(f"{task_name}: acquired lock {locks_to_acquire[1]}, doing work")
        
        thread1 = threading.Thread(target=task_with_ordering, args=("Task 1", 0, 1))
        thread2 = threading.Thread(target=task_with_ordering, args=("Task 2", 1, 0))
        
        thread1.start()
        thread2.start()
        thread1.join()
        thread2.join()
        
        print("Tasks finished normally, no deadlock")

    @contextmanager
    def timeout_lock(self, lock, timeout=5):
        """Lock acquisition with a timeout"""
        acquired = lock.acquire(timeout=timeout)
        if not acquired:
            raise TimeoutError("Timed out acquiring lock")
        try:
            yield
        finally:
            lock.release()
    
    def avoid_deadlock_with_timeout(self):
        print("\n=== Avoiding deadlock via timeouts ===")

        def task_with_timeout(task_name, first_lock, second_lock):
            try:
                print(f"{task_name}: trying to acquire the first lock")
                with self.timeout_lock(first_lock, timeout=2):
                    print(f"{task_name}: acquired the first lock")
                    time.sleep(1)
                    print(f"{task_name}: trying to acquire the second lock")
                    with self.timeout_lock(second_lock, timeout=2):
                        print(f"{task_name}: acquired the second lock, doing work")
            except TimeoutError as e:
                print(f"{task_name}: {e}")
        
        thread1 = threading.Thread(target=task_with_timeout, args=("Task 1", self.lock1, self.lock2))
        thread2 = threading.Thread(target=task_with_timeout, args=("Task 2", self.lock2, self.lock1))
        thread1.start()
        thread2.start()
        thread1.join()
        thread2.join()

# Run the deadlock demos
deadlock_demo = DeadlockDemo()
# Note: the first demo can genuinely deadlock and leave threads hanging
# deadlock_demo.demonstrate_deadlock()
deadlock_demo.avoid_deadlock_with_ordering()
deadlock_demo.avoid_deadlock_with_timeout()

A deadlock occurs when two or more threads each wait for a resource the other holds, leaving all of them unable to proceed. A deadlocked program hangs, which seriously harms system stability.

Demonstration: two threads first acquire lock 1 and lock 2 respectively, then each tries to acquire the lock the other holds, forming a circular wait. Result: the threads wait on each other forever and the program cannot proceed. Detection: by giving join a timeout and then checking whether the threads are still alive, the demo infers that a deadlock occurred. A stack dump of all threads, sketched below, is another practical diagnostic.
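
When a program appears hung, dumping the stack of every thread usually shows exactly which acquire() each one is blocked on. A minimal sketch using the standard faulthandler module:

import faulthandler
import sys

# Print the current traceback of every running thread to stderr.
# Called from a watchdog thread or a debugger session, this reveals
# which lock acquisition each thread is stuck in.
faulthandler.dump_traceback(file=sys.stderr, all_threads=True)

# Alternatively, register it on a signal so a hung process can be
# inspected from outside (POSIX only):
#   import signal
#   faulthandler.register(signal.SIGUSR1, all_threads=True)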

Lock-Ordering Strategy

Core idea: all threads acquire locks in the same order, eliminating circular waits. Implementation: number the locks, and have every thread acquire them in ascending order. Effect: the circular-wait condition is removed, so deadlock cannot occur. Example: the avoid_deadlock_with_ordering method above demonstrates this strategy; a reusable helper is sketched below.
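
A reusable version of the ordering idea sorts any set of locks by a stable key (here the object's id()) before acquiring them. A minimal sketch; the helper name is illustrative, not from the original:

import threading
from contextlib import contextmanager, ExitStack

@contextmanager
def acquire_in_order(*locks):
    """Acquire any number of locks in a globally consistent order."""
    # Sorting by id() gives every thread the same acquisition order
    # for the same set of lock objects, so no cycle can form.
    with ExitStack() as stack:
        for lock in sorted(locks, key=id):
            stack.enter_context(lock)
        yield

lock_a, lock_b = threading.Lock(), threading.Lock()

def task(name, first, second):
    # The two threads pass the locks in opposite orders, but the
    # helper normalizes that, so this cannot deadlock.
    with acquire_in_order(first, second):
        print(f"{name} holds both locks")

t1 = threading.Thread(target=task, args=("task 1", lock_a, lock_b))
t2 = threading.Thread(target=task, args=("task 2", lock_b, lock_a))
t1.start(); t2.start(); t1.join(); t2.join()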

Lock-Timeout Strategy

Core idea: set a timeout when trying to acquire a lock and give up once it expires, avoiding an unbounded wait. Implementation: the custom timeout_lock context manager calls lock.acquire(timeout=...). Effect: a thread that times out abandons the lock request and can take corrective action instead of hanging in a deadlock. Example: the avoid_deadlock_with_timeout method above demonstrates this strategy.

Thread Pools and Task Scheduling

A Custom Thread Pool Implementation

import threading
import queue
import time
import random
from typing import Callable
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class Task:
    """Wrapper for a unit of work"""
    def __init__(self, func: Callable, args: tuple = (), kwargs: dict = None, callback: Callable = None):
        self.func = func
        self.args = args
        self.kwargs = kwargs or {}
        self.callback = callback
        self.result = None
        self.exception = None
        self.completed = threading.Event()
    
    def execute(self):
        """Run the task"""
        try:
            self.result = self.func(*self.args, **self.kwargs)
        except Exception as e:
            self.exception = e
            logger.error(f"Task failed: {e}")
        finally:
            self.completed.set()
            if self.callback:
                self.callback(self)

class CustomThreadPool:
    """A custom thread pool"""
    def __init__(self, max_workers: int = 4, queue_size: int = 100):
        self.max_workers = max_workers
        self.task_queue = queue.Queue(maxsize=queue_size)
        self.workers = []
        self.shutdown_flag = threading.Event()
        self.stats = {
            'tasks_submitted': 0,
            'tasks_completed': 0,
            'tasks_failed': 0
        }
        self.stats_lock = threading.Lock()
        self._start_workers()  # start the worker threads
    
    def _start_workers(self):
        """Start the worker threads"""
        for i in range(self.max_workers):
            worker = threading.Thread(target=self._worker, args=(i,))
            worker.daemon = True
            worker.start()
            self.workers.append(worker)
            logger.info(f"Started worker {i}")
    
    def _worker(self, worker_id: int):
        """Worker main loop"""
        logger.info(f"Worker {worker_id} running")
        while not self.shutdown_flag.is_set():
            try:
                # Fetch a task; the timeout avoids blocking forever
                task = self.task_queue.get(timeout=1)
                
                logger.info(f"Worker {worker_id} picked up a task")
                task.execute()
                with self.stats_lock:
                    if task.exception:
                        self.stats['tasks_failed'] += 1
                    else:
                        self.stats['tasks_completed'] += 1
                
                self.task_queue.task_done()
                logger.info(f"Worker {worker_id} finished a task")
                
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Worker {worker_id} hit an error: {e}")
        logger.info(f"Worker {worker_id} exiting")
    
    def submit(self, func: Callable, *args, **kwargs) -> Task:
        """Submit a task"""
        if self.shutdown_flag.is_set():
            raise RuntimeError("Thread pool is shut down")
        
        task = Task(func, args, kwargs)
        self.task_queue.put(task)
        
        with self.stats_lock:
            self.stats['tasks_submitted'] += 1
        logger.info("Task submitted")
        return task
    
    def submit_with_callback(self, func: Callable, callback: Callable, *args, **kwargs) -> Task:
        """Submit a task with a completion callback"""
        if self.shutdown_flag.is_set():
            raise RuntimeError("Thread pool is shut down")
        task = Task(func, args, kwargs, callback)
        self.task_queue.put(task)
        
        with self.stats_lock:
            self.stats['tasks_submitted'] += 1
        return task

    def shutdown(self, wait: bool = True):
        """Shut down the pool"""
        logger.info("Shutting down thread pool")
        self.shutdown_flag.set()
        if wait:
            # Wait for all queued tasks to finish
            self.task_queue.join()
            
            # Wait for all workers to exit
            for worker in self.workers:
                worker.join()
        
        logger.info("Thread pool shut down")

    def get_stats(self):
        """Return a copy of the statistics"""
        with self.stats_lock:
            return self.stats.copy()

def test_custom_thread_pool():
    """Exercise the custom thread pool"""
    def sample_task(task_id: int, duration: float):
        """Sample task"""
        logger.info(f"Task {task_id} starting, expected to take {duration} seconds")
        time.sleep(duration)
        result = f"Task {task_id} done"
        logger.info(result)
        return result
    
    def task_callback(task: Task):
        """Completion callback"""
        if task.exception:
            logger.error(f"Task failed: {task.exception}")
        else:
            logger.info(f"Task succeeded: {task.result}")

    pool = CustomThreadPool(max_workers=3)  # create the pool

    # Submit tasks
    tasks = []
    for i in range(8):
        duration = random.uniform(0.5, 2.0)
        task = pool.submit_with_callback(sample_task, task_callback, i, duration)
        tasks.append(task)
    
    # Wait for every task to complete
    for task in tasks:
        task.completed.wait()
    
    # Print the statistics
    stats = pool.get_stats()
    logger.info(f"Stats: {stats}")
    pool.shutdown()  # shut down the pool

test_custom_thread_pool()

A thread pool manages and reuses thread resources: by creating a fixed number of threads up front, it avoids the overhead of repeatedly creating and destroying threads and improves concurrency and resource utilization. Thread pools are widely used in servers, crawlers, and asynchronous task processing.

Task wrapper (the Task class): encapsulates the function to run, its arguments, the result, any exception, and a completion flag, and supports a completion callback for handling results asynchronously.

Pool management (the CustomThreadPool class): maintains a fixed set of worker threads that continuously pull tasks from the queue. The thread-safe queue.Queue serves as the task queue, supporting blocking waits and completion notification. A shutdown_flag event enables safe shutdown. Counters for submitted, completed, and failed tasks make the pool easy to monitor, and tasks may carry callbacks for asynchronous post-processing.

Thread safety: a lock guards reads and writes of the statistics, avoiding race conditions.

Flexible submission API: both plain tasks and tasks with callbacks can be submitted, covering different needs. The standard library offers the same facilities out of the box, as sketched below.
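
A minimal sketch of the equivalent behavior with concurrent.futures.ThreadPoolExecutor, which provides submission, results, exceptions, and callbacks without any custom plumbing:

import time
from concurrent.futures import ThreadPoolExecutor

def sample_task(task_id, duration):
    time.sleep(duration)
    return f"task {task_id} done"

def on_done(future):
    # Callbacks receive the Future; exceptions are re-raised
    # by result(), mirroring the Task.exception field above.
    print(future.result())

with ThreadPoolExecutor(max_workers=3) as pool:
    for i in range(8):
        future = pool.submit(sample_task, i, 0.5)
        future.add_done_callback(on_done)
# Leaving the with-block waits for all tasks, like shutdown(wait=True).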

Priority-Based Task Scheduling

import heapq
import threading
import time
from dataclasses import dataclass, field
from typing import Callable
import uuid

@dataclass
class PriorityTask:
    """A task with a priority"""
    priority: int
    func: Callable
    args: tuple = field(default_factory=tuple)
    kwargs: dict = field(default_factory=dict)
    task_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    created_time: float = field(default_factory=time.time)
    
    def __lt__(self, other):
        # A smaller number means a higher priority
        if self.priority != other.priority:
            return self.priority < other.priority
        # Ties are broken by creation time
        return self.created_time < other.created_time

    def execute(self):
        """Run the task"""
        return self.func(*self.args, **self.kwargs)

class PriorityThreadPool:
    """A thread pool with priority scheduling"""
    def __init__(self, max_workers: int = 4):
        self.max_workers = max_workers
        self.task_heap = []
        self.heap_lock = threading.Lock()
        self.task_available = threading.Condition(self.heap_lock)
        self.workers = []
        self.shutdown_flag = threading.Event()
        self.active_tasks = 0
        self.completed_tasks = 0
        self._start_workers()
    
    def _start_workers(self):
        """Start the worker threads"""
        for i in range(self.max_workers):
            worker = threading.Thread(target=self._worker, args=(i,))
            worker.daemon = True
            worker.start()
            self.workers.append(worker)
    
    def _worker(self, worker_id: int):
        """Worker loop"""
        while not self.shutdown_flag.is_set():
            with self.task_available:
                # Wait until a task is available
                while not self.task_heap and not self.shutdown_flag.is_set():
                    self.task_available.wait(timeout=1)
                
                if self.shutdown_flag.is_set():
                    break
                
                if self.task_heap:
                    # Pop the highest-priority task
                    task = heapq.heappop(self.task_heap)
                    self.active_tasks += 1
            
            try:
                print(f"Worker {worker_id} running task {task.task_id} (priority {task.priority})")
                result = task.execute()
                print(f"Task {task.task_id} done, result: {result}")
                
                with self.heap_lock:
                    self.active_tasks -= 1
                    self.completed_tasks += 1
                    
            except Exception as e:
                print(f"Task {task.task_id} failed: {e}")
                with self.heap_lock:
                    self.active_tasks -= 1
    
    def submit(self, func: Callable, priority: int = 5, *args, **kwargs) -> str:
        """Submit a task"""
        task = PriorityTask(priority, func, args, kwargs)
        
        with self.task_available:
            heapq.heappush(self.task_heap, task)
            self.task_available.notify()
        
        print(f"Submitted task {task.task_id} with priority {priority}")
        return task.task_id
    
    def shutdown(self, wait: bool = True):
        """Shut down the pool"""
        self.shutdown_flag.set()
        with self.task_available:
            self.task_available.notify_all()
        
        if wait:
            for worker in self.workers:
                worker.join()
    
    def get_status(self):
        """Report pool status"""
        with self.heap_lock:
            return {
                'pending_tasks': len(self.task_heap),
                'active_tasks': self.active_tasks,
                'completed_tasks': self.completed_tasks
            }

def test_priority_thread_pool():
    """Exercise the priority pool"""
    
    def sample_task(task_name: str, duration: float):
        time.sleep(duration)
        return f"{task_name} done"

    pool = PriorityThreadPool(max_workers=2)
    
    # Submit tasks with different priorities
    tasks = [
        ("low-priority task 1", 10, 1.0),
        ("high-priority task 1", 1, 0.5),
        ("medium-priority task 1", 5, 0.8),
        ("high-priority task 2", 1, 0.3),
        ("low-priority task 2", 10, 1.2),
        ("medium-priority task 2", 5, 0.6),
    ]
    
    for task_name, priority, duration in tasks:
        pool.submit(sample_task, priority, task_name, duration)
        time.sleep(0.1)  # stagger the submissions slightly
    
    # Monitor progress
    while True:
        status = pool.get_status()
        print(f"Status: {status}")
        
        if status['pending_tasks'] == 0 and status['active_tasks'] == 0:
            break
        time.sleep(1)
    pool.shutdown()

test_priority_thread_pool()

High-priority tasks (those with smaller priority numbers) are executed first by the workers.

Tasks with equal priority run in the order they were submitted.

The pool schedules tasks dynamically and allocates thread resources sensibly, and its status report reflects the task queue and execution state in real time.

Task wrapper (the PriorityTask class): @dataclass keeps the task definition concise.

Each task carries a priority, the function to run, its arguments, a unique task ID, and a creation timestamp. The __lt__ method lets the heap order tasks by priority and then by submission time: a smaller number means a higher priority, and among equal priorities the earlier submission runs first.

Pool implementation (the PriorityThreadPool class): maintains a thread-safe task heap (priority queue) together with its lock and condition variable.

Workers wait for tasks to arrive and pop the highest-priority one. Submissions may specify a priority, which defaults to 5. Counters for active and completed tasks make the pool observable, and shutdown waits for every worker to exit.

Thread safety and synchronization: threading.Lock and threading.Condition protect access to the heap and handle cross-thread notification; workers block when no task is available instead of busy-waiting. The standard library's queue.PriorityQueue packages the same mechanics, as sketched below.
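
A minimal sketch of the same scheduling with the standard queue.PriorityQueue, using (priority, counter, payload) tuples so that ties never fall through to comparing the payloads:

import itertools
import queue
import threading

task_queue = queue.PriorityQueue()
tie_breaker = itertools.count()  # monotonically increasing submit order

def submit(func, priority, *args):
    # (priority, submit_order, work) sorts by priority first, then
    # FIFO among equal priorities; func/args are never compared.
    task_queue.put((priority, next(tie_breaker), (func, args)))

def worker():
    while True:
        priority, _, (func, args) = task_queue.get()
        if func is None:          # sentinel: shut the worker down
            break
        func(*args)
        task_queue.task_done()

t = threading.Thread(target=worker)
t.start()
submit(print, 5, "medium priority")
submit(print, 1, "high priority runs first (if still queued)")
submit(None, 99)                  # sentinel, lowest priority
t.join()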

Hands-On Project: A High-Concurrency Task Processing System

import threading
import queue
import time
import logging
import random
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Optional, Dict
from enum import Enum

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("TaskSystem")

# Task status enum
class TaskStatus(Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"

@dataclass
class Task:
    """Task data class"""
    id: str
    data: Any
    priority: int = 0
    status: TaskStatus = TaskStatus.PENDING
    result: Optional[Any] = None
    error: Optional[str] = None
    # default_factory gives each task its own timestamp; a plain
    # "= time.time()" would be evaluated once at class definition
    created_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    
    def __lt__(self, other):
        """For priority-queue ordering; a smaller number means a higher priority"""
        return self.priority < other.priority

class TaskHandler(ABC):
    """Abstract base class for task handlers"""
    
    @abstractmethod
    def handle(self, task: Task) -> Any:
        """Process the task and return a result"""
        pass
    
    @abstractmethod
    def on_success(self, task: Task) -> None:
        """Callback invoked when a task completes successfully"""
        pass
    
    @abstractmethod
    def on_failure(self, task: Task) -> None:
        """Callback invoked when a task fails"""
        pass

class SimpleTaskHandler(TaskHandler):
    """A simple example task handler"""
    
    def handle(self, task: Task) -> Any:
        """Simulate processing; succeed or fail at random"""
        processing_time = random.uniform(0.1, 1.0)
        time.sleep(processing_time)  # simulate processing time
        
        # Randomly decide success or failure
        if random.random() < 0.8:  # 80% success rate
            return f"Processed task {task.id} in {processing_time:.2f}s"
        else:
            raise Exception(f"Task {task.id} failed during processing")
    
    def on_success(self, task: Task) -> None:
        logger.info(f"Task {task.id} completed successfully: {task.result}")
    
    def on_failure(self, task: Task) -> None:
        logger.error(f"Task {task.id} failed: {task.error}")

class TaskWorker(threading.Thread):
    """Task worker thread"""
    
    def __init__(self, worker_id: int, task_queue, result_queue, handler: TaskHandler):
        super().__init__()
        self.worker_id = worker_id
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.handler = handler
        self.daemon = True
        self.running = False
        
    def run(self):
        """Worker main loop"""
        self.running = True
        logger.info(f"Worker {self.worker_id} started")
        
        while self.running:
            try:
                # Fetch a task; the timeout lets us re-check the running flag
                task = self.task_queue.get(timeout=1)
                if task is None:  # stop sentinel received
                    break
                    
                # Skip tasks that have been cancelled
                if task.status == TaskStatus.CANCELLED:
                    logger.info(f"Worker {self.worker_id} skipping cancelled task {task.id}")
                    self.task_queue.task_done()
                    continue
                    
                # Mark the task as processing
                task.status = TaskStatus.PROCESSING
                task.started_at = time.time()
                
                try:
                    # Process the task
                    result = self.handler.handle(task)
                    task.status = TaskStatus.COMPLETED
                    task.result = result
                    self.handler.on_success(task)
                except Exception as e:
                    task.status = TaskStatus.FAILED
                    task.error = str(e)
                    self.handler.on_failure(task)
                    logger.error(f"Worker {self.worker_id} failed to process task {task.id}: {e}")
                
                # Record completion time
                task.completed_at = time.time()
                
                # Push the finished task to the result queue
                self.result_queue.put(task)
                logger.info(f"Worker {self.worker_id} completed task {task.id}")
                
                # Mark the queue item as done
                self.task_queue.task_done()
                
            except queue.Empty:
                continue
                
        logger.info(f"Worker {self.worker_id} stopped")
        
    def stop(self):
        """Stop the worker thread"""
        self.running = False

class TaskManager:
    """Task manager"""
    
    def __init__(self, num_workers: int = 4, max_queue_size: int = 100):
        self.task_queue = queue.PriorityQueue(maxsize=max_queue_size)
        self.result_queue = queue.Queue()
        self.workers = []
        self.handler = SimpleTaskHandler()
        self.num_workers = num_workers
        self.task_counter = 0
        self.task_lock = threading.Lock()
        self.tasks = {}  # all tasks, keyed by task id
        self.running = False
        
    def start(self):
        """Start the task manager and its workers"""
        if self.running:
            logger.warning("Task manager is already running")
            return
            
        logger.info(f"Starting task manager with {self.num_workers} workers")
        self.running = True
        
        # Create the worker threads
        for i in range(self.num_workers):
            worker = TaskWorker(i, self.task_queue, self.result_queue, self.handler)
            self.workers.append(worker)
            worker.start()
            
        # Start the result-processing thread
        self.result_thread = threading.Thread(target=self._process_results, daemon=True)
        self.result_thread.start()
        
        logger.info("Task manager started successfully")
        
    def stop(self):
        """Stop the task manager and its workers"""
        if not self.running:
            logger.warning("Task manager is not running")
            return
            
        logger.info("Stopping task manager")
        self.running = False
        
        # Signal all workers to stop
        for worker in self.workers:
            worker.stop()
            
        # Wait for all workers to finish
        for worker in self.workers:
            worker.join(timeout=5)
            
        logger.info("Task manager stopped")
        
    def submit_task(self, data: Any, priority: int = 0) -> str:
        """Submit a new task to the queue"""
        with self.task_lock:
            task_id = f"task_{self.task_counter}"
            self.task_counter += 1
            
        task = Task(id=task_id, data=data, priority=priority)
        
        # Register the task
        self.tasks[task_id] = task
        
        try:
            # Enqueue on the priority queue
            self.task_queue.put(task, timeout=1)
            logger.info(f"Task {task_id} submitted with priority {priority}")
            return task_id
        except queue.Full:
            logger.error(f"Task queue is full, failed to submit task {task_id}")
            return None
            
    def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
        """Get a task's status"""
        task = self.tasks.get(task_id)
        return task.status if task else None
        
    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task"""
        task = self.tasks.get(task_id)
        if task and task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
            task.status = TaskStatus.CANCELLED
            logger.info(f"Task {task_id} cancelled")
            return True
        return False
        
    def get_result(self, task_id: str) -> Optional[Any]:
        """Get a task's result"""
        task = self.tasks.get(task_id)
        return task.result if task and task.status == TaskStatus.COMPLETED else None
        
    def _process_results(self):
        """Drain the result queue and record finished tasks"""
        while self.running:
            try:
                # Fetch a finished task
                task = self.result_queue.get(timeout=1)
                if task is None:
                    break
                    
                # Update the stored task
                if task.id in self.tasks:
                    self.tasks[task.id] = task
                    
                # Mark the result as handled
                self.result_queue.task_done()
                
            except queue.Empty:
                continue
                
    def wait_all(self) -> None:
        """Block until every queued task has been processed"""
        self.task_queue.join()
        
    def get_stats(self) -> Dict[str, int]:
        """获取系统统计信息"""
        stats = {
            "total_tasks": len(self.tasks),
            "pending_tasks": 0,
            "processing_tasks": 0,
            "completed_tasks": 0,
            "failed_tasks": 0,
            "cancelled_tasks": 0,
        }
        
        for task in self.tasks.values():
            if task.status == TaskStatus.PENDING:
                stats["pending_tasks"] += 1
            elif task.status == TaskStatus.PROCESSING:
                stats["processing_tasks"] += 1
            elif task.status == TaskStatus.COMPLETED:
                stats["completed_tasks"] += 1
            elif task.status == TaskStatus.FAILED:
                stats["failed_tasks"] += 1
            elif task.status == TaskStatus.CANCELLED:
                stats["cancelled_tasks"] += 1
                
        return stats

# Example usage
if __name__ == "__main__":
    # Create the task manager
    manager = TaskManager(num_workers=4, max_queue_size=50)
    
    try:
        # Start the task manager
        manager.start()
        
        # Submit some tasks
        task_ids = []
        for i in range(20):
            priority = random.randint(0, 5)
            task_id = manager.submit_task(f"Data for task {i}", priority)
            if task_id:
                task_ids.append(task_id)
        
        # Cancel one task at random
        if task_ids:
            cancel_id = random.choice(task_ids)
            manager.cancel_task(cancel_id)
        
        # Wait for all tasks to finish
        manager.wait_all()
        
        # Print the statistics
        stats = manager.get_stats()
        print("\nTask Processing Statistics:")
        for key, value in stats.items():
            print(f"{key}: {value}")
            
        # Print a few task results
        print("\nSample Task Results:")
        for i, task_id in enumerate(task_ids[:5]):
            result = manager.get_result(task_id)
            status = manager.get_task_status(task_id)
            print(f"Task {task_id}: Status={status}, Result={result}")
            
    finally:
        # Stop the task manager
        manager.stop()
