JDK 8 ForkJoinPool源码详解(详细注释版)

说明:下文代码是参照 JDK 8 源码整理的简化版本,并非逐字摘录;部分字段、常量和辅助方法(如 AC_SHIFT、SQMASK、tryTerminate 等)未在文中列出,因此无法直接编译,细节请以 OpenJDK 8 原始源码为准。

1. ForkJoinPool核心类源码

/*
 * ForkJoinPool是Java 7引入的并行执行框架的核心类
 * 实现了工作窃取算法(Work-Stealing Algorithm)
 * 专门用于执行ForkJoinTask类型的任务
 */
public class ForkJoinPool extends AbstractExecutorService {
    
    /*
     * Pool state and configuration constants.
     */
    
    // Equals (1 << 15) - 1 = 32767, the same value as MAX_CAP.
    // NOTE(review): not referenced by the code shown in this file.
    private static final int  COUNT_BITS = (1 << 15) - 1;  // 32767
    
    // Upper bound on parallelism (enforced by checkParallelism).
    private static final int  MAX_CAP    = 0x7fff;         // 32767
    
    // Mode / run-state flag bits.
    // FIFO is OR-ed into `config` when asyncMode is true.
    // NOTE(review): FIFO occupies bit 15, which lies inside the 0xffff mask
    // used for the parallelism part of `config`, so the two can collide;
    // masking the parallelism with MAX_CAP (0x7fff) avoids that.
    private static final int  FIFO       = 1 << 15;        // FIFO (async) queue mode
    // NOTE(review): CLEAR_TLS and OWNING are not referenced by the code shown.
    private static final int  CLEAR_TLS  = 1 << 16;        // clear-thread-locals flag
    private static final int  OWNING     = 1 << 17;        // owner flag
    // The following are runState bits.
    private static final int  TERMINATED = 1 << 18;        // pool has terminated
    private static final int  STOP       = 1 << 19;        // shutdownNow requested
    private static final int  SHUTDOWN   = 1 << 20;        // shutdown requested
    private static final int  STARTED    = 1 << 21;        // pool has been started
    
    // Default parallelism: the number of available processors, capped at MAX_CAP.
    private static final int  DEFAULT_PARALLELISM = 
        Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors());
    
    // NOTE(review): not referenced by the code shown in this file.
    private static final int  DEFAULT_COMMON_MAX_SPARES = 256;
    
    // Randomization increments.
    // NOTE(review): none of these three is referenced by the code shown here.
    private static final int  SEED_INCREMENT = 0x9e3779b9;
    private static final int  PROBE_INCREMENT = 0x9e3779b9;
    private static final int  SECONDARY = 0x3c6ef35f;
    
    /*
     * Instance and static fields.
     */
    
    // Control word holding the worker counts; the counts are stored offset by
    // -parallelism (see the private constructor).
    // NOTE(review): declared int, but the private constructor, signalWork and
    // deregisterWorker treat it as a long (long arithmetic and
    // compareAndSwapLong on CTL); JDK 8 declares it `volatile long ctl`.
    volatile int ctl;
    
    // Run-state bits (STARTED, SHUTDOWN, STOP, TERMINATED).
    volatile int runState;
    
    // Low bits: parallelism; higher bits: mode flags such as FIFO.
    final int config;
    
    // Work queues. Indexed with `& (length - 1)` masks, so the length is
    // presumably always a power of two.
    volatile WorkQueue[] workQueues;
    
    // Factory used to create worker threads.
    final ForkJoinWorkerThreadFactory factory;
    
    // Handler for exceptions thrown in workers (may be null).
    final UncaughtExceptionHandler ueh;
    
    // Prefix for worker thread names.
    final String workerNamePrefix;
    
    // The common pool, created by the static initializer.
    private static final ForkJoinPool common;
    
    // NOTE(review): not referenced by the code shown in this file.
    private final Object registrationLock = new Object();
    
    // Steal count kept on the pool itself; getStealCount() adds the
    // per-queue counts to this base value.
    volatile long stealCount;
    
    // Unsafe instance and the field/array offsets (filled in by the static
    // initializer) used for CAS and ordered/volatile accesses.
    private static final sun.misc.Unsafe U = sun.misc.Unsafe.getUnsafe();
    private static final long CTL;
    private static final long RUNSTATE;
    private static final long STEALCOUNT;
    private static final long PARKBLOCKER;
    private static final long QTOP;
    private static final long QLOCK;
    private static final long QSCANSTATE;
    private static final long QPARKER;
    private static final int  ABASE;
    private static final int  ASHIFT;
    
    // Class initialization: resolves the field and array offsets used with
    // Unsafe, then creates the common pool.
    // NOTE(review): static initializers run in textual order. This block, and
    // therefore makeCommonPool(), runs before the later static block that
    // assigns defaultForkJoinWorkerThreadFactory, so that field is still null
    // when makeCommonPool reads it.
    static {
        try {
            CTL = U.objectFieldOffset
                (ForkJoinPool.class.getDeclaredField("ctl"));
            RUNSTATE = U.objectFieldOffset
                (ForkJoinPool.class.getDeclaredField("runState"));
            STEALCOUNT = U.objectFieldOffset
                (ForkJoinPool.class.getDeclaredField("stealCount"));
            PARKBLOCKER = U.objectFieldOffset
                (java.lang.Thread.class.getDeclaredField("parkBlocker"));
            QTOP = U.objectFieldOffset
                (WorkQueue.class.getDeclaredField("top"));
            QLOCK = U.objectFieldOffset
                (WorkQueue.class.getDeclaredField("lock"));
            QSCANSTATE = U.objectFieldOffset
                (WorkQueue.class.getDeclaredField("scanState"));
            QPARKER = U.objectFieldOffset
                (WorkQueue.class.getDeclaredField("parker"));
            Class<?> ak = ForkJoinTask[].class;
            ABASE = U.arrayBaseOffset(ak);
            int scale = U.arrayIndexScale(ak);
            // Slot addresses are computed with shifts, so the element size
            // must be a power of two.
            if ((scale & (scale - 1)) != 0)
                throw new Error("data type scale not a power of two");
            // log2(scale): the offset of element i is (i << ASHIFT) + ABASE.
            ASHIFT = 31 - Integer.numberOfLeadingZeros(scale);
        } catch (Exception e) {
            throw new Error(e);
        }
        
        // makeCommonPool reads system properties and loads classes, so it is
        // run inside doPrivileged.
        common = java.security.AccessController.doPrivileged(
            new java.security.PrivilegedAction<ForkJoinPool>() {
                public ForkJoinPool run() { return makeCommonPool(); }});
    }
    
    /**
     * A double-ended work queue.
     *
     * <p>The owning worker pushes and pops at {@code top} (LIFO order);
     * other threads take tasks from {@code base} (FIFO order). Every removal
     * claims its slot with a CAS, so a task is handed to at most one thread.
     * Queues created with a {@code null} owner receive external submissions;
     * whoever pushes into such a queue must hold {@code lock}.
     */
    static final class WorkQueue {

        // Upper bound on the array length; keeps the int slot offsets
        // ((index << ASHIFT) + ABASE) far from overflow.
        static final int MAXIMUM_QUEUE_CAPACITY = 1 << 26;

        // Lock word for shared queues, CAS-ed through QLOCK (0 = free, 1 = held).
        volatile int lock;

        // Scan state; rewritten by signalWork when this queue's worker is released.
        volatile int scanState;

        // Index of the next free slot (one past the newest task).
        volatile int top;

        // Index of the oldest task. Volatile so that stealers and the owner
        // see each other's progress (JDK 8 declares it volatile as well).
        volatile int base;

        // Circular task buffer, indexed with `& (length - 1)`.
        ForkJoinTask<?>[] array;

        // Owning worker thread, or null for a shared submission queue.
        final ForkJoinWorkerThread owner;

        // Index of this queue in ForkJoinPool.workQueues.
        int poolIndex;

        // Per-queue random seed (used to choose where to look for work).
        int seed;

        // Stealer hint. NOTE(review): not referenced by the code shown.
        volatile int stealHint;

        // Thread currently parked on this queue (unparked by signalWork), or null.
        volatile Thread parker;

        // Task the owner is currently joining (maintained by awaitJoin).
        volatile ForkJoinTask<?> currentJoin;

        // Presumably the stolen task the owner is currently running.
        volatile ForkJoinTask<?> currentSteal;

        /**
         * Creates a queue with an initial array of INITIAL_QUEUE_CAPACITY slots.
         *
         * @param owner the owning worker, or null for a shared queue
         */
        WorkQueue(ForkJoinWorkerThread owner) {
            this.owner = owner;
            this.array = new ForkJoinTask<?>[INITIAL_QUEUE_CAPACITY];
        }

        /**
         * Pushes a task at the top of the queue (LIFO end). Only the owner, or
         * the holder of {@code lock} for a shared queue, may call this.
         * When the push fills the array it is doubled, so a later push never
         * overwrites a task that has not been taken yet.
         *
         * @param task the task to push
         * @throws java.util.concurrent.RejectedExecutionException if the queue
         *         would grow beyond MAXIMUM_QUEUE_CAPACITY
         */
        final void push(ForkJoinTask<?> task) {
            ForkJoinTask<?>[] a = array;
            if (a != null) {
                int s = top;
                int m = a.length - 1;
                U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
                U.putOrderedInt(this, QTOP, s + 1);
                // s - base tasks were queued before this push; if that was
                // already length - 1, the array is now completely full.
                if (s - base >= m)
                    growArray();
            }
        }

        /**
         * Removes and returns the newest task (LIFO end); owner only.
         * The slot is claimed with a CAS, so when only one task is left a
         * concurrent {@link #poll} and this method cannot both obtain it.
         *
         * @return the task, or null if the queue is empty or the last task
         *         was taken by another thread
         */
        final ForkJoinTask<?> pop() {
            ForkJoinTask<?>[] a = array;
            int s;
            if (a != null && (s = top) != base) {
                int newTop = s - 1;
                long j = (((a.length - 1) & newTop) << ASHIFT) + ABASE;
                ForkJoinTask<?> t = (ForkJoinTask<?>)U.getObject(a, j);
                if (t != null && U.compareAndSwapObject(a, j, t, null)) {
                    U.putOrderedInt(this, QTOP, newTop);
                    return t;
                }
            }
            return null;
        }

        /**
         * Removes and returns the oldest task (FIFO end); safe for any thread.
         *
         * @return the task, or null if the queue is empty or the slot was
         *         claimed concurrently
         */
        final ForkJoinTask<?> poll() {
            ForkJoinTask<?>[] a = array;
            int b;
            if (a != null && (b = base) != top) {
                int m = a.length - 1;
                long j = ((m & b) << ASHIFT) + ABASE;
                ForkJoinTask<?> t = (ForkJoinTask<?>)U.getObjectVolatile(a, j);
                if (t != null && base == b &&
                    U.compareAndSwapObject(a, j, t, null)) {
                    base = b + 1;
                    return t;
                }
            }
            return null;
        }

        /**
         * Steals a task for another thread; same as {@link #poll}.
         */
        final ForkJoinTask<?> steal() {
            return poll();
        }

        /**
         * Returns an approximate number of queued tasks (never negative).
         */
        final int size() {
            int n = top - base;
            return (n < 0) ? 0 : n;
        }

        /**
         * Doubles the array (or allocates the initial one) and moves pending
         * tasks into it. Each old slot is claimed with a CAS, so this copy and
         * a concurrent stealer never both take the same task.
         *
         * @return the new array
         * @throws java.util.concurrent.RejectedExecutionException if the new
         *         size would exceed MAXIMUM_QUEUE_CAPACITY
         */
        final ForkJoinTask<?>[] growArray() {
            ForkJoinTask<?>[] oldA = array;
            int size = (oldA != null) ? oldA.length << 1 : INITIAL_QUEUE_CAPACITY;
            if (size <= 0 || size > MAXIMUM_QUEUE_CAPACITY)
                throw new java.util.concurrent.RejectedExecutionException(
                    "Queue capacity exceeded");
            ForkJoinTask<?>[] a = array = new ForkJoinTask<?>[size];
            int oldMask, t, b;
            if (oldA != null && (oldMask = oldA.length - 1) >= 0 &&
                (t = top) - (b = base) > 0) {
                int mask = size - 1;
                do {
                    long oldj = ((long)(b & oldMask) << ASHIFT) + ABASE;
                    long j = ((long)(b & mask) << ASHIFT) + ABASE;
                    ForkJoinTask<?> x = (ForkJoinTask<?>)U.getObjectVolatile(oldA, oldj);
                    if (x != null && U.compareAndSwapObject(oldA, oldj, x, null))
                        U.putObjectVolatile(a, j, x);
                } while (++b != t);
            }
            return a;
        }
    }
    
    /**
     * Base class for tasks that run in a ForkJoinPool.
     *
     * <p>{@code status} is non-negative while the task is pending and negative
     * once it has completed; the bits selected by DONE_MASK record how it
     * completed (NORMAL, CANCELLED or EXCEPTIONAL).
     */
    public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
        
        // Run status: >= 0 while pending, < 0 once completed.
        volatile int status;
        
        // Status encodings.
        private static final int DONE_MASK   = 0xf0000000;  // completion bits
        private static final int NORMAL      = 0xf0000000;  // completed normally
        private static final int CANCELLED   = 0xc0000000;  // cancelled
        private static final int EXCEPTIONAL = 0x80000000;  // completed exceptionally
        private static final int SIGNAL      = 0x00010000;  // a thread may wait on this task's monitor
        
        /**
         * Performs this task's computation; implemented by subclasses.
         *
         * @return true if the task completed normally (doExec then records NORMAL)
         */
        protected abstract boolean exec();
        
        /**
         * Schedules this task for asynchronous execution: onto the current
         * worker's own queue when called from a ForkJoinWorkerThread,
         * otherwise into the common pool.
         *
         * @return this task
         */
        public final ForkJoinTask<V> fork() {
            Thread t;
            if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
                ((ForkJoinWorkerThread)t).workQueue.push(this);
            } else {
                ForkJoinPool.common.externalPush(this);
            }
            return this;
        }
        
        /**
         * Waits for completion and returns the result; an abnormal
         * completion is reported through reportException.
         *
         * @return the computed result
         */
        public final V join() {
            int s;
            if ((s = doJoin() & DONE_MASK) != NORMAL)
                reportException(s);
            return getRawResult();
        }
        
        /**
         * Runs this task in the calling thread, waits for it if necessary and
         * returns the result; an abnormal completion is reported through
         * reportException.
         *
         * @return the computed result
         */
        public final V invoke() {
            int s;
            if ((s = doInvoke() & DONE_MASK) != NORMAL)
                reportException(s);
            return getRawResult();
        }
        
        /**
         * Join implementation. Returns at once if already done. On a worker
         * thread it runs the task directly when {@code tryUnpush} succeeds,
         * otherwise it waits through the pool's awaitJoin. Other threads use
         * externalAwaitDone.
         *
         * @return the status when the wait ended
         */
        private int doJoin() {
            int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
            return (s = status) < 0 ? s :
                ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
                (w = (wt = (ForkJoinWorkerThread)t).workQueue).
                tryUnpush(this) && (s = doExec()) < 0 ? s :
                wt.pool.awaitJoin(w, this, 0L) :
                externalAwaitDone();
        }
        
        /**
         * Invoke implementation: runs the task first, then waits as doJoin
         * does if it has not completed.
         *
         * @return the status when the wait ended
         */
        private int doInvoke() {
            int s; Thread t; ForkJoinWorkerThread wt;
            return (s = doExec()) < 0 ? s :
                ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
                (wt = (ForkJoinWorkerThread)t).pool.
                awaitJoin(wt.workQueue, this, 0L) :
                externalAwaitDone();
        }
        
        /**
         * Runs exec() if the task is still pending. Records NORMAL completion
         * when exec returns true; an exception from exec is recorded through
         * setExceptionalCompletion.
         *
         * @return the status after running (still >= 0 if exec returned false)
         */
        final int doExec() {
            int s; boolean completed;
            if ((s = status) >= 0) {
                try {
                    completed = exec();
                } catch (Throwable rex) {
                    return setExceptionalCompletion(rex);
                }
                if (completed)
                    s = setCompletion(NORMAL);
            }
            return s;
        }
        
        /**
         * Atomically ORs {@code completion} into the status unless the task
         * has already completed. If any of bits 16-30 were set (SIGNAL is
         * bit 16), waiters on this task's monitor are woken.
         *
         * @param completion one of NORMAL, CANCELLED, EXCEPTIONAL
         * @return {@code completion}, or the existing status if already done
         */
        private int setCompletion(int completion) {
            for (int s;;) {
                if ((s = status) < 0)
                    return s;
                if (U.compareAndSwapInt(this, STATUS, s, s | completion)) {
                    if ((s >>> 16) != 0)
                        synchronized (this) { notifyAll(); }
                    return completion;
                }
            }
        }
        
        /**
         * Returns the result stored by the task (may be null).
         */
        public abstract V getRawResult();
        
        /**
         * Stores the task's result.
         */
        protected abstract void setRawResult(V value);
        
        /**
         * Attempts to cancel this task. {@code mayInterruptIfRunning} is
         * ignored.
         *
         * @return true if the task's final state is CANCELLED
         */
        public boolean cancel(boolean mayInterruptIfRunning) {
            return (setCompletion(CANCELLED) & DONE_MASK) == CANCELLED;
        }
        
        /**
         * Returns true once the task has completed in any way.
         */
        public final boolean isDone() {
            return status < 0;
        }
        
        /**
         * Returns true if the task was cancelled.
         */
        public final boolean isCancelled() {
            return (status & DONE_MASK) == CANCELLED;
        }
        
        /**
         * Waits if necessary for completion, then returns the result.
         *
         * @throws CancellationException if the task was cancelled
         * @throws ExecutionException if the task completed exceptionally
         * @throws InterruptedException presumably if a non-worker caller is
         *         interrupted while waiting (externalInterruptibleAwaitDone)
         */
        public final V get() throws InterruptedException, ExecutionException {
            int s = (Thread.currentThread() instanceof ForkJoinWorkerThread) ?
                doJoin() : externalInterruptibleAwaitDone();
            Throwable ex;
            if ((s &= DONE_MASK) == CANCELLED)
                throw new CancellationException();
            if (s == EXCEPTIONAL && (ex = getThrowableException()) != null)
                throw new ExecutionException(ex);
            return getRawResult();
        }
        
        /**
         * Waits at most the given time for completion, then returns the
         * result.
         *
         * <p>Worker threads wait through the pool's awaitJoin with a deadline.
         * Other threads set SIGNAL and wait on this task's monitor, which
         * setCompletion wakes with notifyAll().
         *
         * @param timeout the maximum time to wait
         * @param unit the unit of {@code timeout}
         * @return the computed result
         * @throws CancellationException if the task was cancelled
         * @throws ExecutionException if the task completed exceptionally
         * @throws InterruptedException if the caller is interrupted before or
         *         while waiting on the monitor
         * @throws TimeoutException if the task is still pending at the deadline
         */
        public final V get(long timeout, TimeUnit unit)
            throws InterruptedException, ExecutionException, TimeoutException {
            long nanos = unit.toNanos(timeout);
            if (Thread.interrupted())
                throw new InterruptedException();
            int s = status;
            if (s >= 0 && nanos > 0L) {
                long d = System.nanoTime() + nanos;
                long deadline = (d == 0L) ? 1L : d;   // 0L means "untimed" to awaitJoin
                Thread t = Thread.currentThread();
                if (t instanceof ForkJoinWorkerThread) {
                    ForkJoinWorkerThread wt = (ForkJoinWorkerThread)t;
                    s = wt.pool.awaitJoin(wt.workQueue, this, deadline);
                } else {
                    long ns, ms;
                    while ((s = status) >= 0 &&
                           (ns = deadline - System.nanoTime()) > 0L) {
                        // Set SIGNAL so that setCompletion calls notifyAll(),
                        // then block on the monitor for the remaining time.
                        if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) > 0L &&
                            U.compareAndSwapInt(this, STATUS, s, s | SIGNAL)) {
                            synchronized (this) {
                                if (status >= 0)
                                    wait(ms);         // InterruptedException propagates
                                else
                                    notifyAll();
                            }
                        }
                    }
                }
            }
            if (s >= 0)
                s = status;
            if ((s &= DONE_MASK) != NORMAL) {
                if (s == CANCELLED)
                    throw new CancellationException();
                if (s != EXCEPTIONAL)
                    throw new TimeoutException();     // still pending at the deadline
                Throwable ex = getThrowableException();
                if (ex != null)
                    throw new ExecutionException(ex);
            }
            return getRawResult();
        }
    }
    
    /**
     * A task whose compute() produces a value; the typical base class for
     * divide-and-conquer computations that return a result.
     *
     * @param <V> the result type
     */
    public abstract class RecursiveTask<V> extends ForkJoinTask<V> {
        private static final long serialVersionUID = 5232453952276485270L;
        
        // Value produced by compute(), readable through getRawResult().
        V result;
        
        /**
         * Runs the computation and stores its value as the task result.
         * Always reports normal completion; exceptions propagate to the caller
         * (doExec records them).
         */
        @Override
        protected final boolean exec() {
            setRawResult(compute());
            return true;
        }
        
        /**
         * The computation performed by this task.
         *
         * @return the result
         */
        protected abstract V compute();
        
        @Override
        protected final void setRawResult(V newResult) {
            this.result = newResult;
        }
        
        @Override
        public final V getRawResult() {
            return this.result;
        }
    }
    
    /**
     * A task whose compute() returns nothing; the typical base class for
     * divide-and-conquer computations performed purely for side effects.
     */
    public abstract class RecursiveAction extends ForkJoinTask<Void> {
        private static final long serialVersionUID = 5232453952276485270L;
        
        /**
         * Runs the computation and always reports normal completion;
         * exceptions propagate to the caller (doExec records them).
         */
        @Override
        protected final boolean exec() {
            compute();
            return true;
        }
        
        /**
         * The computation performed by this task.
         */
        protected abstract void compute();
        
        /** There is no result to store; the argument is ignored. */
        @Override
        protected final void setRawResult(Void ignored) { }
        
        /** Always null: an action has no result. */
        @Override
        public final Void getRawResult() { return null; }
    }
    
    /*
     * ForkJoinPool核心方法实现
     */
    
    /**
     * Creates a pool with the given parallelism, the default worker thread
     * factory, no uncaught-exception handler and LIFO (non-async) mode.
     *
     * @param parallelism the parallelism level
     * @throws IllegalArgumentException if parallelism is not in [1, MAX_CAP]
     */
    public ForkJoinPool(int parallelism) {
        this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
    }
    
    /**
     * Creates a fully configured pool.
     *
     * @param parallelism the parallelism level
     * @param factory the worker thread factory
     * @param handler handler for exceptions thrown in workers (may be null)
     * @param asyncMode true for FIFO processing of locally forked tasks,
     *        false for LIFO
     * @throws IllegalArgumentException if parallelism is not in [1, MAX_CAP]
     * @throws NullPointerException if factory is null
     */
    public ForkJoinPool(int parallelism,
                        ForkJoinWorkerThreadFactory factory,
                        UncaughtExceptionHandler handler,
                        boolean asyncMode) {
        this(checkParallelism(parallelism),
             checkFactory(factory),
             handler,
             asyncMode ? FIFO : 0,
             "ForkJoinPool-" + nextPoolId() + "-worker-");
    }
    
    /**
     * Common constructor; performs no argument checks (the common pool is
     * also created through it).
     *
     * @param mode mode flags, e.g. FIFO or 0 for LIFO
     * @param workerNamePrefix prefix for worker thread names
     */
    private ForkJoinPool(int parallelism,
                         ForkJoinWorkerThreadFactory factory,
                         UncaughtExceptionHandler handler,
                         int mode,
                         String workerNamePrefix) {
        this.workerNamePrefix = workerNamePrefix;
        this.factory = factory;
        this.ueh = handler;
        // Mode flags start at bit 15 (FIFO = 1 << 15), so the parallelism must
        // be confined to the low 15 bits; masking with 0xffff would let an
        // out-of-range value spill into the FIFO flag.
        this.config = (parallelism & MAX_CAP) | mode;
        long np = (long)(-parallelism); // worker counts are stored offset by -parallelism
        // NOTE(review): this is a long expression; it only compiles if ctl is long.
        this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }
    
    /**
     * Validates a user-supplied parallelism level.
     *
     * @return {@code parallelism} unchanged when it lies in [1, MAX_CAP]
     * @throws IllegalArgumentException otherwise
     */
    private static int checkParallelism(int parallelism) {
        boolean inRange = parallelism > 0 && parallelism <= MAX_CAP;
        if (!inRange)
            throw new IllegalArgumentException();
        return parallelism;
    }
    
    /**
     * Validates a user-supplied worker thread factory.
     *
     * @return {@code factory} unchanged when non-null
     * @throws NullPointerException if factory is null
     */
    private static ForkJoinWorkerThreadFactory checkFactory
        (ForkJoinWorkerThreadFactory factory) {
        if (factory != null)
            return factory;
        throw new NullPointerException();
    }
    
    /**
     * Returns the next pool number (1, 2, 3, ...), used in worker name
     * prefixes. Synchronized on the class so that numbers are never reused.
     */
    private static synchronized int nextPoolId() {
        poolNumberSequence += 1;
        return poolNumberSequence;
    }
    private static int poolNumberSequence;
    
    /**
     * Returns the shared common pool, created during class initialization.
     *
     * @return the common pool
     */
    public static ForkJoinPool commonPool() {
        final ForkJoinPool shared = common;
        return shared;
    }
    
    /**
     * Creates the common pool, honoring the system properties
     * {@code java.util.concurrent.ForkJoinPool.common.parallelism},
     * {@code ...common.threadFactory} and {@code ...common.exceptionHandler}.
     * Values that cannot be parsed or instantiated are ignored (best effort)
     * and the defaults are used instead.
     *
     * @return the new common pool (LIFO mode)
     */
    private static ForkJoinPool makeCommonPool() {
        int parallelism = DEFAULT_PARALLELISM;
        ForkJoinWorkerThreadFactory factory = null;
        UncaughtExceptionHandler handler = null;
        try {
            String pp = System.getProperty
                ("java.util.concurrent.ForkJoinPool.common.parallelism");
            String fp = System.getProperty
                ("java.util.concurrent.ForkJoinPool.common.threadFactory");
            String hp = System.getProperty
                ("java.util.concurrent.ForkJoinPool.common.exceptionHandler");
            if (pp != null)
                parallelism = Integer.parseInt(pp);
            if (fp != null)
                factory = ((ForkJoinWorkerThreadFactory)ClassLoader.
                           getSystemClassLoader().loadClass(fp).newInstance());
            if (hp != null)
                handler = ((UncaughtExceptionHandler)ClassLoader.
                           getSystemClassLoader().loadClass(hp).newInstance());
        } catch (Exception ignore) {
            // Best effort: a bad property must not break class initialization;
            // fall back to the defaults below.
        }
        // This path uses the unchecked private constructor, so clamp the
        // property value here: negative -> default, too large -> MAX_CAP.
        if (parallelism < 0)
            parallelism = DEFAULT_PARALLELISM;
        else if (parallelism > MAX_CAP)
            parallelism = MAX_CAP;
        if (factory == null) {
            if (System.getSecurityManager() == null) {
                // This method runs from the static initializer that creates
                // `common`, which textually precedes the initializer of
                // defaultForkJoinWorkerThreadFactory; that field can still be
                // null here, so build a default factory directly in that case.
                factory = (defaultForkJoinWorkerThreadFactory != null)
                    ? defaultForkJoinWorkerThreadFactory
                    : new DefaultForkJoinWorkerThreadFactory();
            }
            else // use security-managed default
                factory = new InnocuousForkJoinWorkerThreadFactory();
        }
        // Mode 0 = LIFO, the same value the public constructor uses when
        // asyncMode is false.
        return new ForkJoinPool(parallelism, factory, handler, 0,
                                "ForkJoinPool.commonPool-worker-");
    }
    
    /**
     * 默认工作线程工厂
     */
    static final ForkJoinWorkerThreadFactory defaultForkJoinWorkerThreadFactory;
    static {
        defaultForkJoinWorkerThreadFactory =
            new DefaultForkJoinWorkerThreadFactory();
    }
    
    /**
     * Default factory: each call builds a plain ForkJoinWorkerThread bound
     * to the requesting pool.
     */
    static final class DefaultForkJoinWorkerThreadFactory
        implements ForkJoinWorkerThreadFactory {
        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            ForkJoinWorkerThread worker = new ForkJoinWorkerThread(pool);
            return worker;
        }
    }
    
    /**
     * Arranges for asynchronous execution of the given task in this pool.
     *
     * <p>This overloads {@code Executor.execute(Runnable)} and overrides
     * nothing in AbstractExecutorService, so it carries no {@code @Override}
     * (with one it would not compile).
     *
     * @param task the task
     * @throws NullPointerException if the task is null
     */
    public void execute(ForkJoinTask<?> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
    }
    
    /**
     * Submits a ForkJoinTask for execution in this pool. This is an overload,
     * not an override, hence no {@code @Override}.
     *
     * @param task the task
     * @return the same task, for use as a Future
     * @throws NullPointerException if the task is null
     */
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task;
    }
    
    /**
     * Submits a Callable for execution in this pool.
     *
     * <p>The adapted task is pushed into <em>this</em> pool; calling fork()
     * would instead target the current worker's pool, or the common pool
     * from a non-worker thread.
     *
     * @throws NullPointerException if the task is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(Callable<T> task) {
        if (task == null)
            throw new NullPointerException();
        ForkJoinTask<T> job = new ForkJoinTask.AdaptedCallable<T>(task);
        externalPush(job);
        return job;
    }
    
    /**
     * Submits a Runnable for execution in this pool; the resulting Future
     * yields null.
     *
     * @throws NullPointerException if the task is null
     */
    @Override
    public ForkJoinTask<?> submit(Runnable task) {
        if (task == null)
            throw new NullPointerException();
        ForkJoinTask<Void> job = new ForkJoinTask.AdaptedRunnable<Void>(task, null);
        externalPush(job);
        return job;
    }
    
    /**
     * Submits a Runnable for execution in this pool; the resulting Future
     * yields {@code result}.
     *
     * @throws NullPointerException if the task is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(Runnable task, T result) {
        if (task == null)
            throw new NullPointerException();
        ForkJoinTask<T> job = new ForkJoinTask.AdaptedRunnable<T>(task, result);
        externalPush(job);
        return job;
    }
    
    /**
     * Fast path for submissions from outside the pool. The task is pushed
     * onto the shared queue selected by the caller's probe only if that
     * queue exists, its lock can be taken immediately, and it is empty.
     * Every other case (uninitialized probe, no queues yet, contention,
     * non-empty queue) is handled by externalSubmit.
     *
     * @param task the task to submit
     */
    final void externalPush(ForkJoinTask<?> task) {
        WorkQueue[] ws; WorkQueue q; int m;
        int r = ThreadLocalRandom.getProbe();
        if (r != 0 &&                                        // probe initialized
            (ws = workQueues) != null && (m = ws.length - 1) >= 0 &&
            (q = ws[m & r & SQMASK]) != null &&
            U.compareAndSwapInt(q, QLOCK, 0, 1)) {           // shared queue: lock first
            boolean pushed = false;
            try {
                if (q.base == q.top) {                       // very conservative
                    q.push(task);
                    pushed = true;
                }
            } finally {
                q.lock = 0;                                  // volatile write releases the lock
            }
            if (pushed) {
                if (q.base != q.top)
                    signalWork(ws, q);
                return;
            }
        }
        externalSubmit(task);
    }
    
    /**
     * Full submission path for non-worker threads. It initializes the
     * caller's probe, rejects the task once runState is negative, creates
     * the shared queue for the caller's slot if missing, and pushes the task
     * under that queue's lock. On lock contention it retries with a new probe.
     *
     * @param task the task to submit
     * @throws RejectedExecutionException if the pool is terminating
     */
    private void externalSubmit(ForkJoinTask<?> task) {
        int r;                                    // initialize caller's probe
        if ((r = ThreadLocalRandom.getProbe()) == 0) {
            ThreadLocalRandom.localInit();
            r = ThreadLocalRandom.getProbe();
        }
        for (;;) {
            WorkQueue[] ws; WorkQueue q; int rs, m, k;
            boolean move = false;
            if ((rs = runState) < 0) {
                tryTerminate(false, false);       // help terminate
                throw new RejectedExecutionException();
            }
            else if ((ws = workQueues) != null && (m = ws.length - 1) >= 0) {
                // Same slot selection as externalPush, so its fast path can
                // find the queue created here.
                if ((q = ws[k = r & m & SQMASK]) == null) {
                    WorkQueue nq = new WorkQueue(null);
                    nq.seed = r;
                    nq.stealHint = r;             // publication racy
                    // If another thread installs a queue first, the CAS fails
                    // and the next iteration simply uses that queue.
                    U.compareAndSwapObject(ws, ((long)k << ASHIFT) + ABASE, null, nq);
                }
                else if (!U.compareAndSwapInt(q, QLOCK, 0, 1))
                    move = true;                  // contended: try another slot
                else {
                    try {
                        q.push(task);             // always enqueue; never drop the task
                    } finally {
                        q.lock = 0;
                    }
                    signalWork(ws, q);
                    return;
                }
            }
            else if ((rs & STARTED) == 0 || (ws = workQueues) == null)
                tryInitializeWorker();
            if (move)
                r = ThreadLocalRandom.advanceProbe(r);
        }
    }
    
    /**
     * Tries to activate one more worker while ctl reports too few active
     * workers (ctl < 0).
     *
     * <p>If the low 32 bits of ctl are zero there is no idle worker to
     * release, so a new worker is created when ADD_WORKER is set. Otherwise
     * the idle worker whose index is held in the low bits of ctl is released:
     * ctl is CAS-updated, the worker's scanState is rewritten and its parked
     * thread (if any) is unparked. At most one worker is signalled per call;
     * a failed CAS re-reads ctl and retries.
     *
     * <p>NOTE(review): {@code q} is not used here. stackPred, nextWait,
     * SS_SEQ, INACTIVE, E_MASK, ADD_WORKER and SMASK are not declared in this
     * file. The new ctl value also ORs v.seed into the high 32 bits, which
     * overlaps the AC field if AC_SHIFT >= 32, so verify it against the
     * intended ctl layout. createWorker() is called without first incrementing
     * any count in ctl.
     *
     * @param ws the current work queue array
     * @param q the queue that received work (unused here)
     */
    final void signalWork(WorkQueue[] ws, WorkQueue q) {
        long c; int sp, i; WorkQueue v; Thread p;
        while ((c = ctl) < 0L) {                       // too few active
            if ((sp = (int)c) == 0) {                  // no idle workers
                if ((c & ADD_WORKER) != 0L)            // too few total
                    createWorker();
                break;
            }
            if (ws == null || ws.length <= (i = sp & SMASK))
                break;
            if ((v = ws[i]) == null)
                break;
            int vs = (sp + SS_SEQ) & ~INACTIVE;        // next scanState
            int vc = (int)(c >> AC_SHIFT) + (v.scanState < 0 ? 0 : 1);
            long nc = (v.stackPred & SMASK) | ((long)vc << AC_SHIFT) |
                ((long)(v.nextWait & E_MASK) << TC_SHIFT) |
                ((long)(v.seed & 0xffffffffL) << 32);
            if (U.compareAndSwapLong(this, CTL, c, nc)) {
                v.scanState = vs;
                if ((p = v.parker) != null)
                    U.unpark(p);                       // wake the released worker
                break;
            }
        }
    }
    
    /**
     * Creates and starts one worker thread via the pool's factory.
     * When the factory is missing or returns null, or when creation or
     * start-up throws, the failure is handed to deregisterWorker; that call
     * rethrows any exception.
     *
     * @return true if a worker was started
     */
    private boolean createWorker() {
        ForkJoinWorkerThreadFactory fac = factory;
        ForkJoinWorkerThread thread = null;
        Throwable failure = null;
        if (fac != null) {
            try {
                thread = fac.newThread(this);
                if (thread != null) {
                    thread.start();
                    return true;
                }
            } catch (Throwable t) {
                failure = t;
            }
        }
        deregisterWorker(thread, failure);
        return false;
    }
    
    /**
     * Cleans up after a worker that is exiting or could not be started.
     * It removes the worker's queue from {@code workQueues}, decrements the
     * total-count field of ctl and disables the queue. Finally it either
     * calls ForkJoinTask.helpExpungeStaleExceptions() (when {@code ex} is
     * null) or rethrows {@code ex}.
     *
     * <p>NOTE(review): w.queuing and the single-argument deregisterWorker(w)
     * are not declared in this file. Only TC is decremented here, while
     * signalWork calls createWorker without incrementing any count first, so
     * confirm that the counts stay balanced.
     *
     * @param wt the worker thread, or null if it could not be created
     * @param ex the exception that caused the exit, or null
     */
    final void deregisterWorker(ForkJoinWorkerThread wt, Throwable ex) {
        WorkQueue w = null;
        if (wt != null && (w = wt.workQueue) != null) {
            int idx = w.poolIndex;
            w.poolIndex = 0;                // invalidate
            WorkQueue[] ws = workQueues;
            // Clear the slot only if it still holds this queue; the index may
            // already have been reused by another worker.
            if (ws != null && idx >= 0 && idx < ws.length && ws[idx] == w)
                ws[idx] = null;
        }
        long c;                             // decrement total count (TC)
        do {} while (!U.compareAndSwapLong
                     (this, CTL, c = ctl, ((c & ~TC_MASK) |
                                           (((c & TC_MASK) - TC_UNIT) & TC_MASK))));
        if (w != null) {
            synchronized (this) {
                if (w.array != null) {      // check if previously enabled
                    w.array = null;         // disable
                    if (!w.queuing)         // remove from registry
                        deregisterWorker(w);
                }
            }
        }
        if (ex == null)                     // help clean on way out
            ForkJoinTask.helpExpungeStaleExceptions();
        else                                // rethrow
            U.throwException(ex);
    }
    
    /**
     * Helps or waits until {@code task} completes, or until {@code deadline}
     * passes.
     *
     * <p>While waiting, the calling worker runs tasks from its own queue
     * (which may include {@code task} itself), then tries to run a task
     * stolen from another queue. Only when neither is possible does it park
     * briefly and re-check. Unlike calling task.doJoin() from here, this
     * cannot recurse back into awaitJoin.
     *
     * @param w the calling worker's queue
     * @param task the task being joined
     * @param deadline a System.nanoTime() deadline, or 0L for an untimed join
     * @return the task status when waiting stopped: negative if it completed,
     *         non-negative on timeout; 0 if {@code task} or {@code w} is null
     */
    final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
        int s = 0;
        if (task != null && w != null) {
            ForkJoinTask<?> prevJoin = w.currentJoin;
            w.currentJoin = task;
            boolean interrupted = false;
            try {
                while ((s = task.status) >= 0) {
                    long parkNanos = 1000L * 1000L;      // re-check at least every 1 ms
                    if (deadline != 0L) {
                        long remaining = deadline - System.nanoTime();
                        if (remaining <= 0L)
                            break;                       // timed out; s is still >= 0
                        if (remaining < parkNanos)
                            parkNanos = remaining;
                    }
                    ForkJoinTask<?> t = w.pop();
                    if (t != null) {
                        t.doExec();                      // local work, possibly `task` itself
                        continue;
                    }
                    if (tryHelpStealer(w, task) >= 0)
                        continue;                        // ran a stolen task; re-check
                    // Nothing runnable: park briefly. Clear the interrupt flag
                    // first (it is restored below), otherwise park would return
                    // immediately on every iteration.
                    if (Thread.interrupted())
                        interrupted = true;
                    U.park(false, parkNanos);
                }
            } finally {
                w.currentJoin = prevJoin;
                if (interrupted)
                    Thread.currentThread().interrupt();
            }
        }
        return s;
    }
    
    /**
     * Tries to run one task taken from another queue while the caller waits
     * for {@code task}. Each queue is visited at most once, starting from an
     * index derived from the caller's seed. The task is taken with
     * WorkQueue.poll(), which claims its slot by CAS, so no extra locking is
     * needed.
     *
     * @param w the caller's queue
     * @param task the task being joined; scanning stops once it has completed
     * @return 0 if a task was stolen and run, -1 otherwise
     */
    private int tryHelpStealer(WorkQueue w, ForkJoinTask<?> task) {
        WorkQueue[] ws = workQueues;
        if (ws != null && w != null) {
            int n = ws.length;
            int origin = w.seed;
            for (int k = 0; k < n; ++k) {
                if (task != null && task.status < 0)
                    break;                          // joined task already done
                // Assumes a power-of-two table, like every other index mask here.
                WorkQueue v = ws[(origin + k) & (n - 1)];
                if (v != null && v != w && v.base != v.top) {
                    ForkJoinTask<?> t = v.poll();
                    if (t != null) {
                        ForkJoinTask<?> prevSteal = w.currentSteal;
                        w.currentSteal = t;
                        try {
                            t.doExec();
                        } finally {
                            w.currentSteal = prevSteal;
                        }
                        return 0;
                    }
                }
            }
        }
        return -1;
    }
    
    /**
     * Runs {@code task} (ignoring null). The task is stored in
     * {@code currentForkJoinTask} while it runs, and the field is cleared
     * afterwards.
     *
     * <p>NOTE(review): currentForkJoinTask is not declared in this file; if
     * it is a field of the pool, it is shared by every worker thread, so
     * confirm the intent.
     */
    final void runTask(ForkJoinTask<?> task) {
        if (task == null)
            return;
        currentForkJoinTask = task;
        task.doExec();
        currentForkJoinTask = null;
    }
    
    /**
     * Returns an estimate of the number of actively running workers.
     *
     * <p>The AC field of ctl is stored offset by -parallelism (the private
     * constructor initializes it to -parallelism), so the parallelism is
     * added back before the value is reported.
     *
     * @return the estimated active count, never negative
     */
    public int getActiveThreadCount() {
        int c = ctl;
        int ac = (config & MAX_CAP) + (short)(c >>> AC_SHIFT);
        return (ac > 0) ? ac : 0;
    }
    
    /**
     * Returns true if no worker is actively running.
     *
     * <p>The AC field of ctl holds (active - parallelism), so the pool is
     * quiescent when parallelism + AC <= 0. A raw AC of 0 means every worker
     * is active, which is the opposite of quiescent.
     *
     * @return true if there are no active workers
     */
    public boolean isQuiescent() {
        int c = ctl;
        return (config & MAX_CAP) + (short)(c >>> AC_SHIFT) <= 0;
    }
    
    /**
     * Returns the total number of steals: the pool's own stealCount plus
     * the nsteals counter of every live queue.
     *
     * @return the total steal count (an estimate while the pool is running)
     */
    public long getStealCount() {
        long total = stealCount;
        WorkQueue[] queues = workQueues;
        if (queues != null) {
            for (WorkQueue queue : queues) {
                if (queue != null)
                    total += queue.nsteals;
            }
        }
        return total;
    }
    
    /**
     * Returns the targeted parallelism level of this pool (at least 1).
     *
     * <p>The parallelism occupies the low 15 bits of config (MAX_CAP). The
     * FIFO mode flag is bit 15, so masking with 0xffff would add 32768 for
     * asyncMode pools.
     *
     * @return the parallelism level
     */
    public int getParallelism() {
        int par = config & MAX_CAP;
        return (par > 0) ? par : 1;
    }
    
    /**
     * Returns the targeted parallelism level of the common pool (at least 1,
     * consistent with getParallelism()).
     *
     * @return the common pool's parallelism
     */
    public static int getCommonPoolParallelism() {
        return commonParallelism;
    }
    // Evaluated after the static block that creates `common` (it appears
    // later in the file). Only the parallelism bits of config are kept, so a
    // mode flag in bit 15 cannot make the value wrong or negative.
    private static final int commonParallelism =
        Math.max(1, (common != null) ? (common.config & MAX_CAP) : 0);
    
    /**
     * Initiates an orderly shutdown. On the first call it sets SHUTDOWN and
     * invokes terminate(false) while holding submissionQueuesLock; later
     * calls have no effect.
     */
    @Override
    public void shutdown() {
        final ReentrantLock lock = this.submissionQueuesLock;
        lock.lock();
        try {
            boolean alreadyShutdown = (runState & SHUTDOWN) != 0;
            if (!alreadyShutdown) {
                runState |= SHUTDOWN;
                terminate(false);
            }
        } finally {
            lock.unlock();
        }
    }
    
    /**
     * Stops the pool immediately. It sets SHUTDOWN and STOP, then calls
     * terminate(true), which cancels the queued tasks. Later calls, or calls
     * after termination, have no effect.
     *
     * @return always an empty list; this implementation does not report the
     *         tasks that never ran
     */
    @Override
    public List<Runnable> shutdownNow() {
        final ReentrantLock lock = this.submissionQueuesLock;
        lock.lock();
        try {
            if ((runState & (STOP | TERMINATED)) == 0) {
                // shutdownNow implies shutdown; terminate(true) performs the
                // cancellation, so the queues are not walked twice here.
                runState |= SHUTDOWN | STOP;
                terminate(true);
            }
        } finally {
            lock.unlock();
        }
        return Collections.emptyList();
    }
    
    /**
     * Returns true once shutdown() or shutdownNow() has been requested.
     *
     * <p>The previous check ({@code runState >= 0}) was true even for a freshly
     * created pool, whose runState is 0.
     *
     * @return true if SHUTDOWN or STOP is set
     */
    @Override
    public boolean isShutdown() {
        return (runState & (SHUTDOWN | STOP)) != 0;
    }
    
    /**
     * Returns true once the TERMINATED bit has been set in runState.
     */
    @Override
    public boolean isTerminated() {
        int rs = runState;
        return (rs & TERMINATED) == TERMINATED;
    }
    
    /**
     * Blocks until the pool has terminated or the timeout elapses.
     * It waits on the {@code termination} condition while holding
     * submissionQueuesLock.
     *
     * @return true if terminated, false if the timeout elapsed first
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit)
        throws InterruptedException {
        long remaining = unit.toNanos(timeout);
        final ReentrantLock lock = this.submissionQueuesLock;
        lock.lock();
        try {
            while ((runState & TERMINATED) == 0) {
                if (remaining <= 0)
                    return false;
                remaining = termination.awaitNanos(remaining);
            }
            return true;
        } finally {
            lock.unlock();
        }
    }
    
    /**
     * Advances the run state towards termination; the caller holds
     * submissionQueuesLock (needed for termination.signalAll()).
     *
     * <p>Parked workers are unparked so that they observe the new state.
     * With {@code now == true} (shutdownNow), every queued task is cancelled
     * and the pool is marked TERMINATED. With {@code now == false} (shutdown),
     * queued tasks are left to run, and the pool is marked TERMINATED only
     * if every queue is already empty.
     *
     * <p>NOTE(review): TERMINATED is set without waiting for workers that are
     * still running a task. After a shutdown() with pending tasks, nothing in
     * this file calls terminate again once the queues drain.
     *
     * @param now true for shutdownNow semantics, false for shutdown semantics
     */
    private void terminate(boolean now) {
        runState |= (now ? STOP : SHUTDOWN);
        WorkQueue[] ws = workQueues;             // may be null before initialization
        boolean drained = true;
        if (ws != null) {
            for (WorkQueue w : ws) {
                if (w == null)
                    continue;
                Thread p = w.parker;
                if (p != null)
                    U.unpark(p);                 // let parked workers see the new state
                if (now)
                    w.cancelAll();
                else if (w.size() > 0)
                    drained = false;             // shutdown must not discard pending work
            }
        }
        if (now || drained) {
            runState |= TERMINATED;
            termination.signalAll();
        }
    }
}

2. ForkJoinWorkerThread工作线程类

/*
 * ForkJoinWorkerThread is the worker thread of a ForkJoinPool;
 * each worker owns exactly one work queue.
 */
public class ForkJoinWorkerThread extends Thread {
    
    // Pool this worker runs tasks for.
    final ForkJoinPool pool;
    
    // Queue owned by this worker, created in the constructor.
    final ForkJoinPool.WorkQueue workQueue;
    
    /**
     * Creates a worker for the given pool.
     *
     * <p>The thread name is the pool's workerNamePrefix plus a counter that
     * is shared by all pools in the JVM, so the numbers within one pool are
     * not necessarily consecutive.
     *
     * @param pool the pool this worker belongs to
     */
    protected ForkJoinWorkerThread(ForkJoinPool pool) {
        super(pool.workerNamePrefix + getAndIncrementPoolThreadNumber());
        this.pool = pool;
        this.workQueue = new ForkJoinPool.WorkQueue(this);
        // Workers are daemon threads (as in the JDK), so a pool that was never
        // shut down does not keep the JVM alive once other threads finish.
        setDaemon(true);
    }
    
    /**
     * Returns the pool this worker belongs to.
     *
     * @return the owning ForkJoinPool
     */
    public ForkJoinPool getPool() {
        return pool;
    }
    
    /**
     * Returns this worker's queue.
     *
     * @return the worker's WorkQueue
     */
    public ForkJoinPool.WorkQueue getWorkQueue() {
        return workQueue;
    }
    
    /**
     * Starts the thread; no behavior beyond Thread.start().
     */
    @Override
    public void start() {
        super.start();
    }
    
    /**
     * Worker main loop. It calls onStart(), runs pool.runWorker(workQueue),
     * and then always calls onTermination(exception) followed by
     * pool.deregisterWorker(this, exception). The first exception thrown is
     * the one reported.
     */
    @Override
    public void run() {
        // NOTE(review): presumably a run-at-most-once guard (JDK 8 comments it
        // "only run once" and allocates the array later). The WorkQueue
        // constructor in this file already allocates `array`, so as written
        // this body never executes; confirm against runWorker.
        if (workQueue.array == null) {
            Throwable exception = null;
            try {
                onStart();
                pool.runWorker(workQueue);
            } catch (Throwable ex) {
                exception = ex;
            } finally {
                try {
                    onTermination(exception);
                } catch (Throwable ex) {
                    if (exception == null)
                        exception = ex;
                } finally {
                    pool.deregisterWorker(this, exception);
                }
            }
        }
    }
    
    /**
     * Hook called when the worker starts, before any task runs.
     * Subclasses may override it; the default does nothing.
     */
    protected void onStart() {
    }
    
    /**
     * Hook called when the worker is about to exit.
     * Subclasses may override it; the default does nothing.
     *
     * @param exception the exception that caused the exit, or null on
     *        normal termination
     */
    protected void onTermination(Throwable exception) {
    }
    
    /**
     * Returns the current value of the JVM-wide worker counter and
     * increments it.
     */
    private static synchronized int getAndIncrementPoolThreadNumber() {
        return poolThreadNumber++;
    }
    private static int poolThreadNumber;
}

3. 使用示例

/**
 * 递归任务示例 - 计算数组元素之和
 */
public class SumTask extends RecursiveTask<Integer> {
    private static final int THRESHOLD = 1000;
    private final int[] array;
    private final int start;
    private final int end;
    
    public SumTask(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }
    
    @Override
    protected Integer compute() {
        // 如果任务足够小,直接计算
        if (end - start <= THRESHOLD) {
            int sum = 0;
            for (int i = start; i < end; i++) {
                sum += array[i];
            }
            return sum;
        } else {
            // 否则分解任务
            int mid = (start + end) / 2;
            SumTask leftTask = new SumTask(array, start, mid);
            SumTask rightTask = new SumTask(array, mid, end);
            
            // 异步执行左任务
            leftTask.fork();
            
            // 同步执行右任务并获取结果
            int rightResult = rightTask.compute();
            
            // 获取左任务结果
            int leftResult = leftTask.join();
            
            return leftResult + rightResult;
        }
    }
}

/**
 * Recursive action example: doubles every element of
 * {@code array[start, end)} in place.
 */
public class ProcessArrayAction extends RecursiveAction {
    /** Ranges of at most this many elements are processed sequentially. */
    private static final int THRESHOLD = 1000;
    private final int[] array;
    private final int start;
    private final int end;
    
    /**
     * @param array the array to modify in place
     * @param start first index, inclusive
     * @param end last index, exclusive
     */
    public ProcessArrayAction(int[] array, int start, int end) {
        this.array = array;
        this.start = start;
        this.end = end;
    }
    
    @Override
    protected void compute() {
        if (end - start <= THRESHOLD) {
            // Small range: process directly.
            for (int i = start; i < end; i++) {
                array[i] = processElement(array[i]);
            }
        } else {
            // Split in half; the unsigned shift avoids (start + end) overflow.
            int mid = (start + end) >>> 1;
            ProcessArrayAction leftAction = new ProcessArrayAction(array, start, mid);
            ProcessArrayAction rightAction = new ProcessArrayAction(array, mid, end);
            
            // Runs both halves and waits for both.
            invokeAll(leftAction, rightAction);
        }
    }
    
    /** Per-element transformation: doubles the value (int overflow wraps). */
    private static int processElement(int element) {
        return element * 2;
    }
}

/**
 * Usage example: sums an array with SumTask, then doubles every element in
 * place with ProcessArrayAction.
 */
public class ForkJoinExample {
    public static void main(String[] args) {
        // Create the pool (parallelism = number of available processors).
        ForkJoinPool pool = new ForkJoinPool();
        try {
            // Build a large array. Values are kept below 1000 because SumTask
            // accumulates in an int: the sum of 0..999999 (499999500000)
            // would overflow and print a wrong total. With i % 1000 the total
            // is 499500000, which fits in an int.
            int[] array = new int[1000000];
            for (int i = 0; i < array.length; i++) {
                array[i] = i % 1000;
            }
            
            // Sum with a RecursiveTask.
            SumTask sumTask = new SumTask(array, 0, array.length);
            int sum = pool.invoke(sumTask);
            System.out.println("Sum: " + sum);
            
            // Transform the array in place with a RecursiveAction.
            ProcessArrayAction processAction = new ProcessArrayAction(array, 0, array.length);
            pool.invoke(processAction);
        } finally {
            // Release the pool's threads even if a task throws.
            pool.shutdown();
        }
    }
}

4. 核心设计要点总结

4.1 工作窃取算法(Work-Stealing)

  • 每个工作线程都有自己的双端队列(deque)
  • 工作线程默认从自己队列的顶部获取任务(LIFO);若构造时 asyncMode 为 true,则按 FIFO 顺序处理本地任务
  • 空闲线程从其他线程队列的底部窃取任务(FIFO)
  • 减少竞争,提高并行效率

4.2 自适应调整

  • 根据工作负载动态调整线程数量
  • 线程空闲时会阻塞,有任务时自动唤醒
  • 支持任务的负载均衡

4.3 高性能设计

  • 使用无锁算法和CAS操作
  • 减少线程上下文切换
  • 优化内存访问模式

4.4 任务分解模式

  • 支持分治算法
  • 递归任务分解
  • 自动的任务调度和合并

4.5 线程池管理

  • 支持优雅关闭
  • 异常处理机制
  • 线程生命周期管理

ForkJoinPool通过这些设计,为并行计算提供了高效、灵活的解决方案,特别适合处理可以分解为多个子任务的计算密集型任务。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值