JDK 8 ForkJoinPool源码详解(简化重构的示意代码 + 详细注释,并非JDK原始源码)
1. ForkJoinPool核心类源码
/*
 * ForkJoinPool: the core class of the fork/join framework introduced in Java 7.
 * It implements a work-stealing scheduler and executes ForkJoinTask instances.
 *
 * NOTE(review): this listing is a simplified rendition (ForkJoinTask.get(long,
 * TimeUnit) says so explicitly). It relies on members that are not declared
 * here, e.g. AC_SHIFT, AC_MASK, TC_SHIFT, TC_MASK, TC_UNIT, SQMASK, SMASK,
 * SS_SEQ, INACTIVE, E_MASK, ADD_WORKER, INITIAL_QUEUE_CAPACITY, LIFO_QUEUE,
 * ForkJoinTask.STATUS, submissionQueuesLock, termination, tryTerminate and
 * tryInitializeWorker. As a result it does not compile on its own.
 */
public class ForkJoinPool extends AbstractExecutorService {

    /* ---------------- Constants ---------------- */

    // NOTE(review): despite its name this is a value (32767), not a bit count; unused here.
    private static final int COUNT_BITS = (1 << 15) - 1; // 32767
    // Upper bound on parallelism; also the mask used to read parallelism out of config.
    private static final int MAX_CAP = 0x7fff; // 32767

    // Mode flag stored in config: local queues operate in FIFO (async) order.
    // 1 << 15 lies inside the low 16 bits of config. Parallelism is therefore
    // extracted with MAX_CAP rather than 0xffff.
    private static final int FIFO = 1 << 15;
    // Flag bits; SHUTDOWN, STOP, TERMINATED and STARTED are kept in runState.
    private static final int CLEAR_TLS = 1 << 16;
    private static final int OWNING = 1 << 17;
    private static final int TERMINATED = 1 << 18;
    private static final int STOP = 1 << 19;
    private static final int SHUTDOWN = 1 << 20;
    private static final int STARTED = 1 << 21;

    // Default parallelism: the number of available processors, capped at MAX_CAP.
    private static final int DEFAULT_PARALLELISM =
        Math.min(MAX_CAP, Runtime.getRuntime().availableProcessors());

    // Unused in this listing.
    private static final int DEFAULT_COMMON_MAX_SPARES = 256;

    // Seed/probe increments (0x9e3779b9 is the 32-bit golden-ratio constant); unused here.
    private static final int SEED_INCREMENT = 0x9e3779b9;
    private static final int PROBE_INCREMENT = 0x9e3779b9;
    private static final int SECONDARY = 0x3c6ef35f;

    /* ---------------- Fields ---------------- */

    // Main control word. It is read as a long and updated with compareAndSwapLong
    // through the CTL offset, so it must be declared long. It packs the
    // active/total worker counts (AC_* / TC_* fields). Its low 32 bits hold
    // idle-worker state; 0 means "no idle worker" (see signalWork).
    volatile long ctl;
    // Run-state flag bits (STARTED, SHUTDOWN, STOP, TERMINATED).
    volatile int runState;
    // Parallelism (low bits, at most MAX_CAP) combined with mode flags such as FIFO.
    final int config;
    // Work queues. Indices are masked with (length - 1), so the length is
    // expected to be a power of two.
    volatile WorkQueue[] workQueues;
    // Factory used to create new worker threads.
    final ForkJoinWorkerThreadFactory factory;
    // Uncaught-exception handler supplied at construction (may be null).
    final UncaughtExceptionHandler ueh;
    // Prefix for worker thread names.
    final String workerNamePrefix;
    // The shared common pool, created in the static initializer below.
    private static final ForkJoinPool common;
    // Lock object for worker registration; unused in this listing.
    private final Object registrationLock = new Object();
    // Base steal count; getStealCount() adds the per-queue counts to it.
    volatile long stealCount;

    // Default worker factory. It is initialized here, ahead of the static block
    // below, because that block calls makeCommonPool(), which reads this field.
    // Static initializers run in textual order, so initializing it later would
    // leave it null at that point.
    static final ForkJoinWorkerThreadFactory defaultForkJoinWorkerThreadFactory =
        new DefaultForkJoinWorkerThreadFactory();

    // Unsafe access plus field/array offsets used for CAS and ordered writes.
    private static final sun.misc.Unsafe U = sun.misc.Unsafe.getUnsafe();
    private static final long CTL;
    private static final long RUNSTATE;
    private static final long STEALCOUNT;
    private static final long PARKBLOCKER;
    private static final long QTOP;
    private static final long QLOCK;
    private static final long QSCANSTATE;
    private static final long QPARKER;
    private static final int ABASE;
    private static final int ASHIFT;

    static {
        try {
            CTL = U.objectFieldOffset
                (ForkJoinPool.class.getDeclaredField("ctl"));
            RUNSTATE = U.objectFieldOffset
                (ForkJoinPool.class.getDeclaredField("runState"));
            STEALCOUNT = U.objectFieldOffset
                (ForkJoinPool.class.getDeclaredField("stealCount"));
            PARKBLOCKER = U.objectFieldOffset
                (java.lang.Thread.class.getDeclaredField("parkBlocker"));
            QTOP = U.objectFieldOffset
                (WorkQueue.class.getDeclaredField("top"));
            QLOCK = U.objectFieldOffset
                (WorkQueue.class.getDeclaredField("lock"));
            QSCANSTATE = U.objectFieldOffset
                (WorkQueue.class.getDeclaredField("scanState"));
            QPARKER = U.objectFieldOffset
                (WorkQueue.class.getDeclaredField("parker"));
            Class<?> ak = ForkJoinTask[].class;
            ABASE = U.arrayBaseOffset(ak);
            int scale = U.arrayIndexScale(ak);
            // Slot offsets are computed with shifts, so the element size must be a power of two.
            if ((scale & (scale - 1)) != 0)
                throw new Error("data type scale not a power of two");
            ASHIFT = 31 - Integer.numberOfLeadingZeros(scale);
        } catch (Exception e) {
            throw new Error(e);
        }
        // makeCommonPool() reads system properties and may load classes, so run it privileged.
        common = java.security.AccessController.doPrivileged(
            new java.security.PrivilegedAction<ForkJoinPool>() {
                public ForkJoinPool run() { return makeCommonPool(); }});
    }

    /**
     * A work-stealing deque. The owner pushes and pops at {@code top}; other
     * threads take the oldest element at {@code base} via poll()/steal().
     */
    static final class WorkQueue {
        // Lock word used by external submitters (CAS 0 -> 1 through the QLOCK offset).
        volatile int lock;
        // Scan state; signalWork treats a negative value as inactive when recounting.
        volatile int scanState;
        // Index of the next slot to push into (owner end).
        volatile int top;
        // Index of the oldest element (steal end). It is volatile because thieves
        // in other threads read and advance it.
        volatile int base;
        // Circular task buffer; slots are addressed as (index & (length - 1)).
        ForkJoinTask<?>[] array;
        // Owning worker thread, or null for queues created by external submitters.
        final ForkJoinWorkerThread owner;
        // Slot of this queue in ForkJoinPool.workQueues.
        int poolIndex;
        // Per-queue random seed / probe value.
        int seed;
        // Hint recorded at creation (externalSubmit stores the submitter's probe here).
        volatile int stealHint;
        // Thread parked on this queue; signalWork unparks it.
        volatile Thread parker;
        // Task currently being joined by the owner (saved/restored by awaitJoin).
        volatile ForkJoinTask<?> currentJoin;
        // Task most recently stolen by the owner; not referenced in this listing.
        volatile ForkJoinTask<?> currentSteal;

        /**
         * Creates a queue with a fixed-capacity buffer.
         *
         * @param owner the owning worker, or null for a shared submission queue
         */
        WorkQueue(ForkJoinWorkerThread owner) {
            this.owner = owner;
            this.array = new ForkJoinTask<?>[INITIAL_QUEUE_CAPACITY];
        }

        /**
         * Pushes a task at the top (owner end, LIFO).
         *
         * NOTE(review): there is no capacity check or resize. Once
         * top - base reaches array.length, pushes overwrite unconsumed slots.
         */
        final void push(ForkJoinTask<?> task) {
            ForkJoinTask<?>[] a;
            int s = top;
            if ((a = array) != null) {
                int m = a.length - 1;
                // Ordered stores: the task slot is written before the new top is published.
                U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
                U.putOrderedInt(this, QTOP, s + 1);
            }
        }

        /**
         * Pops the most recently pushed task from the top (LIFO).
         *
         * @return the task, or null if the queue is empty or a thief took the
         *         last element first
         */
        final ForkJoinTask<?> pop() {
            ForkJoinTask<?>[] a;
            int s;
            if ((a = array) != null && (s = top) != base) {
                long j = (((a.length - 1) & --s) << ASHIFT) + ABASE;
                ForkJoinTask<?> t = (ForkJoinTask<?>) U.getObjectVolatile(a, j);
                // Claim the slot with a CAS (poll() does the same). A plain store
                // here would let a concurrent thief take the same last task.
                if (t != null && U.compareAndSwapObject(a, j, t, null)) {
                    U.putOrderedInt(this, QTOP, s);
                    return t;
                }
            }
            return null;
        }

        /**
         * Takes the oldest task from the base (FIFO end).
         *
         * @return the task, or null if empty or the CAS on the slot lost a race
         */
        final ForkJoinTask<?> poll() {
            ForkJoinTask<?>[] a;
            int b;
            if ((a = array) != null && (b = base) != top) {
                int m = a.length - 1;
                long j = ((m & b) << ASHIFT) + ABASE;
                ForkJoinTask<?> t = (ForkJoinTask<?>) U.getObjectVolatile(a, j);
                if (t != null && base == b &&
                    U.compareAndSwapObject(a, j, t, null)) {
                    base = b + 1;
                    return t;
                }
            }
            return null;
        }

        /**
         * Steals a task; identical to poll().
         */
        final ForkJoinTask<?> steal() {
            return poll();
        }

        /**
         * Returns an estimate of the number of queued tasks (never negative).
         */
        final int size() {
            int n = top - base;
            return (n < 0) ? 0 : n;
        }
    }

    /**
     * Base class for tasks executed by a ForkJoinPool.
     *
     * <p>Declared static so that it, and its subclasses, can be instantiated
     * without an enclosing ForkJoinPool instance. In the JDK this is a separate
     * top-level class in java.util.concurrent.
     */
    public abstract static class ForkJoinTask<V> implements Future<V>, Serializable {
        // Completion status: negative once done; the bits in DONE_MASK encode how it completed.
        volatile int status;
        private static final int DONE_MASK = 0xf0000000; // completion bits
        private static final int NORMAL = 0xf0000000; // completed normally
        private static final int CANCELLED = 0xc0000000; // cancelled
        private static final int EXCEPTIONAL = 0x80000000; // completed with an exception
        // Signal bit: setCompletion() calls notifyAll() if any bit at or above it was set.
        private static final int SIGNAL = 0x00010000;

        /**
         * Performs the task's computation. Subclasses must implement it.
         *
         * @return true if the task completed normally
         */
        protected abstract boolean exec();

        /**
         * Arranges asynchronous execution. When called from a
         * ForkJoinWorkerThread, the task is pushed onto that worker's queue.
         * Otherwise it is submitted to the common pool.
         *
         * @return this task
         */
        public final ForkJoinTask<V> fork() {
            Thread t;
            if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
                ((ForkJoinWorkerThread) t).workQueue.push(this);
            } else {
                ForkJoinPool.common.externalPush(this);
            }
            return this;
        }

        /**
         * Waits for completion and returns the result. Abnormal completion is
         * reported via reportException.
         */
        public final V join() {
            int s;
            if ((s = doJoin() & DONE_MASK) != NORMAL)
                reportException(s);
            return getRawResult();
        }

        /**
         * Runs the task in the calling thread, waits if needed, and returns the result.
         */
        public final V invoke() {
            int s;
            if ((s = doInvoke() & DONE_MASK) != NORMAL)
                reportException(s);
            return getRawResult();
        }

        /**
         * Join implementation. It returns the status immediately if the task is
         * done. On a worker thread it first tries to unpush and run the task
         * locally, then falls back to the pool's awaitJoin. On other threads it
         * waits externally.
         */
        private int doJoin() {
            int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
            return (s = status) < 0 ? s :
                ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
                (w = (wt = (ForkJoinWorkerThread) t).workQueue).
                tryUnpush(this) && (s = doExec()) < 0 ? s :
                wt.pool.awaitJoin(w, this, 0L) :
                externalAwaitDone();
        }

        /**
         * Invoke implementation. It executes the task directly, then waits
         * (via awaitJoin on workers, externally otherwise) if it did not finish.
         */
        private int doInvoke() {
            int s; Thread t; ForkJoinWorkerThread wt;
            return (s = doExec()) < 0 ? s :
                ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
                (wt = (ForkJoinWorkerThread) t).pool.
                awaitJoin(wt.workQueue, this, 0L) :
                externalAwaitDone();
        }

        /**
         * Runs exec() if the task is not yet done, recording normal or
         * exceptional completion.
         *
         * @return the resulting status
         */
        final int doExec() {
            int s; boolean completed;
            if ((s = status) >= 0) {
                try {
                    completed = exec();
                } catch (Throwable rex) {
                    return setExceptionalCompletion(rex);
                }
                if (completed)
                    s = setCompletion(NORMAL);
            }
            return s;
        }

        /**
         * Marks completion with the given bits unless already done, and wakes
         * monitor waiters if any signal bit was set.
         *
         * @return the completion value, or the existing status if already done
         */
        private int setCompletion(int completion) {
            for (int s;;) {
                if ((s = status) < 0)
                    return s;
                if (U.compareAndSwapInt(this, STATUS, s, s | completion)) {
                    if ((s >>> 16) != 0)
                        synchronized (this) { notifyAll(); }
                    return completion;
                }
            }
        }

        /**
         * Returns the result without waiting (may be null if not yet computed).
         */
        public abstract V getRawResult();

        /**
         * Stores a result value.
         */
        protected abstract void setRawResult(V value);

        /**
         * Attempts to cancel the task; {@code mayInterruptIfRunning} is ignored.
         *
         * @return true if the task ended up in the cancelled state
         */
        public boolean cancel(boolean mayInterruptIfRunning) {
            return (setCompletion(CANCELLED) & DONE_MASK) == CANCELLED;
        }

        /**
         * Returns true once the task has completed in any way.
         */
        public final boolean isDone() {
            return status < 0;
        }

        /**
         * Returns true if the task was cancelled.
         */
        public final boolean isCancelled() {
            return (status & DONE_MASK) == CANCELLED;
        }

        /**
         * Waits if necessary and returns the result.
         *
         * @throws CancellationException if the task was cancelled
         * @throws ExecutionException if the task completed exceptionally
         * @throws InterruptedException if a non-worker caller is interrupted while waiting
         */
        public final V get() throws InterruptedException, ExecutionException {
            int s = (Thread.currentThread() instanceof ForkJoinWorkerThread) ?
                doJoin() : externalInterruptibleAwaitDone();
            Throwable ex;
            if ((s &= DONE_MASK) == CANCELLED)
                throw new CancellationException();
            if (s == EXCEPTIONAL && (ex = getThrowableException()) != null)
                throw new ExecutionException(ex);
            return getRawResult();
        }

        /**
         * Timed variant of get().
         *
         * NOTE(review): this is simplified. The timeout is ignored and the call
         * delegates to the untimed get(), so it can block indefinitely and never
         * throws TimeoutException.
         */
        public final V get(long timeout, TimeUnit unit)
            throws InterruptedException, ExecutionException, TimeoutException {
            if (Thread.interrupted())
                throw new InterruptedException();
            // Simplified implementation; a real timed wait is more involved.
            return get();
        }
    }

    /**
     * A result-bearing task for divide-and-conquer algorithms: exec() stores
     * compute()'s result and reports normal completion.
     */
    public abstract static class RecursiveTask<V> extends ForkJoinTask<V> {
        private static final long serialVersionUID = 5232453952276485270L;

        // Result produced by compute().
        V result;

        /** The main computation performed by this task. */
        protected abstract V compute();

        @Override
        public final V getRawResult() {
            return result;
        }

        @Override
        protected final void setRawResult(V value) {
            result = value;
        }

        @Override
        protected final boolean exec() {
            result = compute();
            return true;
        }
    }

    /**
     * A resultless task: exec() runs compute() and reports normal completion.
     */
    public abstract static class RecursiveAction extends ForkJoinTask<Void> {
        // NOTE(review): same value as RecursiveTask's serialVersionUID; confirm against the JDK.
        private static final long serialVersionUID = 5232453952276485270L;

        /** The main computation performed by this task. */
        protected abstract void compute();

        /** Always null: actions produce no result. */
        @Override
        public final Void getRawResult() { return null; }

        /** Ignored: actions produce no result. */
        @Override
        protected final void setRawResult(Void mustBeNull) { }

        @Override
        protected final boolean exec() {
            compute();
            return true;
        }
    }

    /* ---------------- Construction ---------------- */

    /**
     * Creates a pool whose parallelism equals the number of available
     * processors (capped at MAX_CAP). It uses the default thread factory, no
     * uncaught-exception handler, and LIFO local queues.
     */
    public ForkJoinPool() {
        this(DEFAULT_PARALLELISM, defaultForkJoinWorkerThreadFactory, null, false);
    }

    /**
     * Creates a pool with the given parallelism and default settings otherwise.
     *
     * @param parallelism the parallelism level
     * @throws IllegalArgumentException if parallelism is not in 1..MAX_CAP
     */
    public ForkJoinPool(int parallelism) {
        this(parallelism, defaultForkJoinWorkerThreadFactory, null, false);
    }

    /**
     * Creates a fully configured pool.
     *
     * @param parallelism the parallelism level (1..MAX_CAP)
     * @param factory the factory for worker threads (non-null)
     * @param handler the uncaught-exception handler, or null
     * @param asyncMode true for FIFO local queues, false for LIFO
     * @throws IllegalArgumentException if parallelism is out of range
     * @throws NullPointerException if factory is null
     */
    public ForkJoinPool(int parallelism,
                        ForkJoinWorkerThreadFactory factory,
                        UncaughtExceptionHandler handler,
                        boolean asyncMode) {
        this(checkParallelism(parallelism),
             checkFactory(factory),
             handler,
             asyncMode ? FIFO : 0,
             "ForkJoinPool-" + nextPoolId() + "-worker-");
    }

    /**
     * Internal constructor; performs no argument validation.
     */
    private ForkJoinPool(int parallelism,
                         ForkJoinWorkerThreadFactory factory,
                         UncaughtExceptionHandler handler,
                         int mode,
                         String workerNamePrefix) {
        this.workerNamePrefix = workerNamePrefix;
        this.factory = factory;
        this.ueh = handler;
        this.config = (parallelism & MAX_CAP) | mode;
        // The counts in ctl are stored offset by -parallelism.
        long np = (long) (-parallelism);
        this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }

    /**
     * Validates a user-supplied parallelism level.
     *
     * @throws IllegalArgumentException if not in 1..MAX_CAP
     */
    private static int checkParallelism(int parallelism) {
        if (parallelism <= 0 || parallelism > MAX_CAP)
            throw new IllegalArgumentException();
        return parallelism;
    }

    /**
     * Validates a user-supplied worker factory.
     *
     * @throws NullPointerException if factory is null
     */
    private static ForkJoinWorkerThreadFactory checkFactory
        (ForkJoinWorkerThreadFactory factory) {
        if (factory == null)
            throw new NullPointerException();
        return factory;
    }

    /**
     * Returns the next pool sequence number (starting at 1), used in worker names.
     */
    private static synchronized int nextPoolId() {
        return ++poolNumberSequence;
    }

    // Guarded by the class monitor via nextPoolId().
    private static int poolNumberSequence;

    /**
     * Returns the common pool instance.
     */
    public static ForkJoinPool commonPool() {
        return common;
    }

    /**
     * Creates the common pool. Parallelism, thread factory and exception
     * handler can be overridden via system properties.
     */
    private static ForkJoinPool makeCommonPool() {
        int parallelism = DEFAULT_PARALLELISM;
        ForkJoinWorkerThreadFactory factory = null;
        UncaughtExceptionHandler handler = null;
        try {
            String pp = System.getProperty
                ("java.util.concurrent.ForkJoinPool.common.parallelism");
            String fp = System.getProperty
                ("java.util.concurrent.ForkJoinPool.common.threadFactory");
            String hp = System.getProperty
                ("java.util.concurrent.ForkJoinPool.common.exceptionHandler");
            if (pp != null)
                parallelism = Integer.parseInt(pp);
            if (fp != null)
                factory = ((ForkJoinWorkerThreadFactory) ClassLoader.
                           getSystemClassLoader().loadClass(fp).newInstance());
            if (hp != null)
                handler = ((UncaughtExceptionHandler) ClassLoader.
                           getSystemClassLoader().loadClass(hp).newInstance());
        } catch (Exception ignore) {
            // Best effort. A malformed property or unloadable class keeps the
            // defaults for that setting and for every setting read after it.
        }
        // The private constructor does not validate, so clamp the property value here.
        // Zero is accepted; getParallelism() reports it as 1.
        if (parallelism < 0)
            parallelism = DEFAULT_PARALLELISM;
        else if (parallelism > MAX_CAP)
            parallelism = MAX_CAP;
        if (factory == null) {
            if (System.getSecurityManager() == null)
                factory = defaultForkJoinWorkerThreadFactory;
            else // use a security-managed default
                factory = new InnocuousForkJoinWorkerThreadFactory();
        }
        return new ForkJoinPool(parallelism, factory, handler, LIFO_QUEUE,
                                "ForkJoinPool.commonPool-worker-");
    }

    /**
     * Default factory: creates a plain ForkJoinWorkerThread for the pool.
     */
    static final class DefaultForkJoinWorkerThreadFactory
        implements ForkJoinWorkerThreadFactory {
        public final ForkJoinWorkerThread newThread(ForkJoinPool pool) {
            return new ForkJoinWorkerThread(pool);
        }
    }

    /* ---------------- Submission ---------------- */

    /**
     * Arranges asynchronous execution of a ForkJoinTask in this pool.
     * This is an overload, not an override of Executor.execute(Runnable).
     *
     * @throws NullPointerException if task is null
     */
    public void execute(ForkJoinTask<?> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
    }

    /**
     * Arranges asynchronous execution of a Runnable in this pool.
     *
     * @throws NullPointerException if task is null
     */
    @Override
    public void execute(Runnable task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(new ForkJoinTask.AdaptedRunnable<Void>(task, null));
    }

    /**
     * Submits the task to this pool and waits for its result.
     *
     * @return the task's result
     * @throws NullPointerException if task is null
     */
    public <T> T invoke(ForkJoinTask<T> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task.join();
    }

    /**
     * Submits a ForkJoinTask for execution in this pool (an overload, not an override).
     *
     * @return the task itself
     * @throws NullPointerException if task is null
     */
    public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
        if (task == null)
            throw new NullPointerException();
        externalPush(task);
        return task;
    }

    /**
     * Submits a Callable for execution in this pool.
     * It uses externalPush rather than fork(): fork() would target the
     * caller's own pool or the common pool, not this pool.
     *
     * @throws NullPointerException if task is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(Callable<T> task) {
        if (task == null)
            throw new NullPointerException();
        ForkJoinTask<T> job = new ForkJoinTask.AdaptedCallable<T>(task);
        externalPush(job);
        return job;
    }

    /**
     * Submits a Runnable for execution in this pool; the task's result is null.
     *
     * @throws NullPointerException if task is null
     */
    @Override
    public ForkJoinTask<?> submit(Runnable task) {
        if (task == null)
            throw new NullPointerException();
        ForkJoinTask<?> job = new ForkJoinTask.AdaptedRunnable<Void>(task, null);
        externalPush(job);
        return job;
    }

    /**
     * Submits a Runnable for execution in this pool; the task's result is {@code result}.
     *
     * @throws NullPointerException if task is null
     */
    @Override
    public <T> ForkJoinTask<T> submit(Runnable task, T result) {
        if (task == null)
            throw new NullPointerException();
        ForkJoinTask<T> job = new ForkJoinTask.AdaptedRunnable<T>(task, result);
        externalPush(job);
        return job;
    }

    /**
     * Fast path for external submission: pushes onto the caller's shared queue
     * if it exists, is empty and its lock can be taken. Otherwise it defers to
     * externalSubmit.
     */
    final void externalPush(ForkJoinTask<?> task) {
        WorkQueue[] ws; WorkQueue q; int m;
        int r = ThreadLocalRandom.getProbe(); // 0 means the probe is not initialized yet
        if ((ws = workQueues) != null && (m = ws.length - 1) >= 0 && r != 0 &&
            (q = ws[m & r & SQMASK]) != null && q.base == q.top && // very conservative
            U.compareAndSwapInt(q, QLOCK, 0, 1)) {
            // Shared queues may be pushed to by several external threads, so push under the lock.
            q.push(task);
            q.lock = 0;
            if (q.base != q.top)
                signalWork(ws, q);
            return;
        }
        externalSubmit(task);
    }

    /**
     * Full external submission. It initializes the caller's probe if needed,
     * rejects when the pool is shutting down, and creates a shared queue for
     * the probe's slot if absent. It then enqueues the task under that queue's lock.
     *
     * @throws RejectedExecutionException if the pool is shut down
     */
    private void externalSubmit(ForkJoinTask<?> task) {
        int r; // initialize the caller's probe
        if ((r = ThreadLocalRandom.getProbe()) == 0) {
            ThreadLocalRandom.localInit();
            r = ThreadLocalRandom.getProbe();
        }
        for (;;) {
            WorkQueue[] ws; WorkQueue q; int rs, m, k;
            boolean move = false;
            // The state flags are positive bit masks, so test them explicitly (not the sign).
            if (((rs = runState) & (SHUTDOWN | STOP | TERMINATED)) != 0) {
                tryTerminate(false, false); // help terminate
                throw new RejectedExecutionException();
            }
            else if ((ws = workQueues) != null && (m = ws.length - 1) >= 0) {
                if ((q = ws[k = r & m]) == null) {
                    WorkQueue nq = new WorkQueue(null);
                    nq.seed = r;
                    nq.stealHint = r; // publication racy
                    // On success the next iteration finds nq in slot k; only move after a lost race.
                    if (!U.compareAndSwapObject(ws, ((long) k << ASHIFT) + ABASE, null, nq))
                        move = true;
                }
                else if (!U.compareAndSwapInt(q, QLOCK, 0, 1))
                    move = true; // contended: move to another slot and retry
                else {
                    // Enqueue unconditionally while holding the lock so the task is never dropped.
                    q.push(task);
                    q.lock = 0;
                    if (q.base != q.top)
                        signalWork(ws, q);
                    break;
                }
            }
            else if ((rs & STARTED) == 0 || (ws = workQueues) == null)
                tryInitializeWorker();
            if (move)
                r = ThreadLocalRandom.advanceProbe(r);
        }
    }

    /* ---------------- Workers ---------------- */

    /**
     * Wakes an idle worker, or creates one, while ctl reports too few active workers.
     *
     * NOTE(review): reads WorkQueue.stackPred and WorkQueue.nextWait, which
     * WorkQueue does not declare in this listing.
     */
    final void signalWork(WorkQueue[] ws, WorkQueue q) {
        long c; int sp, i; WorkQueue v; Thread p;
        while ((c = ctl) < 0L) { // too few active
            if ((sp = (int) c) == 0) { // no idle workers
                if ((c & ADD_WORKER) != 0L) // too few total
                    createWorker();
                break;
            }
            if (ws == null || ws.length <= (i = sp & SMASK))
                break;
            if ((v = ws[i]) == null)
                break;
            int vs = (sp + SS_SEQ) & ~INACTIVE; // next scanState
            int vc = (int) (c >> AC_SHIFT) + (v.scanState < 0 ? 0 : 1);
            long nc = (v.stackPred & SMASK) | ((long) vc << AC_SHIFT) |
                ((long) (v.nextWait & E_MASK) << TC_SHIFT) |
                ((long) (v.seed & 0xffffffffL) << 32);
            if (U.compareAndSwapLong(this, CTL, c, nc)) {
                v.scanState = vs;
                if ((p = v.parker) != null)
                    U.unpark(p);
                break;
            }
        }
    }

    /**
     * Creates and starts a worker through the factory. On failure it undoes
     * the registration via deregisterWorker, which rethrows any exception.
     *
     * @return true if a worker was started
     */
    private boolean createWorker() {
        ForkJoinWorkerThreadFactory fac = factory;
        Throwable ex = null;
        ForkJoinWorkerThread wt = null;
        try {
            if (fac != null && (wt = fac.newThread(this)) != null) {
                wt.start();
                return true;
            }
        } catch (Throwable rex) {
            ex = rex;
        }
        deregisterWorker(wt, ex);
        return false;
    }

    /**
     * Removes a terminating (or never-started) worker. It clears the worker's
     * slot, decrements the total count in ctl, disables its queue, and rethrows
     * ex if non-null.
     *
     * NOTE(review): WorkQueue.queuing and a deregisterWorker(WorkQueue)
     * overload are not declared in this listing.
     */
    final void deregisterWorker(ForkJoinWorkerThread wt, Throwable ex) {
        WorkQueue w = null;
        if (wt != null && (w = wt.workQueue) != null) {
            int idx = w.poolIndex;
            w.poolIndex = 0; // invalidate
            WorkQueue[] ws = workQueues;
            if (ws != null && idx >= 0 && idx < ws.length)
                ws[idx] = null; // remove from the queue table
        }
        long c; // decrement the total count
        do {} while (!U.compareAndSwapLong
                     (this, CTL, c = ctl, ((c & ~TC_MASK) |
                                           (((c & TC_MASK) - TC_UNIT) & TC_MASK))));
        if (w != null) {
            synchronized (this) {
                if (w.array != null) { // check if previously enabled
                    w.array = null; // disable
                    if (!w.queuing) // remove from registry
                        deregisterWorker(w);
                }
            }
        }
        if (ex == null) // help clean on the way out
            ForkJoinTask.helpExpungeStaleExceptions();
        else // rethrow
            U.throwException(ex);
    }

    /**
     * Helps or waits for {@code task} to complete on behalf of worker queue w.
     *
     * NOTE(review): several issues remain in this simplified version.
     * {@code deadline} is unused. doJoin() on a worker thread calls back into
     * awaitJoin(), so an incomplete task leads to unbounded recursion; a
     * blocking wait is needed here. When tryHelpStealer returns -1, that value
     * is returned as the status, and join() would treat it as NORMAL.
     *
     * @return the task status, or 0/-1 as described above
     */
    final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
        int s = 0;
        if (task != null && w != null) {
            ForkJoinTask<?> prevJoin = w.currentJoin;
            w.currentJoin = task;
            if (task.status >= 0) {
                if (w.base == w.top || (s = tryHelpStealer(w, task)) >= 0)
                    s = w.lock == 0 ? task.doJoin() : 0;
            }
            w.currentJoin = prevJoin;
        }
        return s;
    }

    /**
     * Tries to steal and run one task from another queue. It probes up to 101
     * slots, starting at w's seed.
     *
     * NOTE(review): runTask is declared on ForkJoinPool, not WorkQueue, in
     * this listing.
     *
     * @return 0 if a task was stolen and run, -1 otherwise
     */
    private int tryHelpStealer(WorkQueue w, ForkJoinTask<?> task) {
        WorkQueue[] ws;
        if ((ws = workQueues) != null) {
            for (int retries = 0;;) {
                WorkQueue u = w; // extract ordering for lock
                int i = (w.seed + retries) & (ws.length - 1);
                WorkQueue v = ws[i];
                if (v != null && v != w && v.base != v.top) {
                    if (u.lock == 0 && v.lock == 0 &&
                        U.compareAndSwapInt(u, QLOCK, 0, 1)) {
                        ForkJoinTask<?>[] a = v.array;
                        int b = v.base, m = a != null ? a.length - 1 : 0;
                        if (v.base == b && a != null && m > 0) {
                            // Byte offset of slot (b & m): scaled index plus the array base offset.
                            long j = (((long) (m & b)) << ASHIFT) + ABASE;
                            ForkJoinTask<?> t = (ForkJoinTask<?>) U.getObjectVolatile(a, j);
                            if (t != null && v.base == b &&
                                U.compareAndSwapObject(a, j, t, null)) {
                                v.base = b + 1;
                                u.lock = 0;
                                w.runTask(t);
                                return 0;
                            }
                        }
                        u.lock = 0;
                    }
                }
                if (++retries > 100)
                    break;
            }
        }
        return -1;
    }

    /**
     * Executes a task, recording it as the current task while it runs.
     *
     * NOTE(review): currentForkJoinTask is not declared in this listing.
     */
    final void runTask(ForkJoinTask<?> task) {
        if (task != null) {
            (currentForkJoinTask = task).doExec();
            currentForkJoinTask = null;
        }
    }

    /* ---------------- Monitoring ---------------- */

    /**
     * Returns the active-count field of ctl, or 0 if it is negative.
     */
    public int getActiveThreadCount() {
        long c = ctl;
        int ac = (short) (c >>> AC_SHIFT);
        return (ac >= 0) ? ac : 0;
    }

    /**
     * Returns true if the active-count bits of ctl are all zero.
     */
    public boolean isQuiescent() {
        return (ctl & AC_MASK) == 0;
    }

    /**
     * Returns stealCount plus the per-queue steal counts.
     *
     * NOTE(review): WorkQueue.nsteals is not declared in this listing.
     */
    public long getStealCount() {
        long count = stealCount;
        WorkQueue[] ws; WorkQueue w;
        if ((ws = workQueues) != null) {
            for (int i = 0; i < ws.length; ++i) {
                if ((w = ws[i]) != null)
                    count += w.nsteals;
            }
        }
        return count;
    }

    /**
     * Returns the parallelism level (at least 1). It masks with MAX_CAP so the
     * FIFO mode bit (1 << 15) is not counted.
     */
    public int getParallelism() {
        int par = config & MAX_CAP;
        return (par > 0) ? par : 1;
    }

    /**
     * Returns the common pool's parallelism as configured (may be 0).
     */
    public static int getCommonPoolParallelism() {
        return commonParallelism;
    }

    // Parallelism bits of the common pool's config (mode flags masked off).
    private static final int commonParallelism =
        common != null ? (common.config & MAX_CAP) : 0;

    /* ---------------- Lifecycle ---------------- */

    /**
     * Initiates shutdown.
     *
     * NOTE(review): this delegates to terminate(false), which cancels every
     * queued task and marks the pool TERMINATED immediately. That contradicts
     * the ExecutorService contract that previously submitted tasks still run.
     */
    @Override
    public void shutdown() {
        final ReentrantLock lock = this.submissionQueuesLock;
        lock.lock();
        try {
            if ((runState & SHUTDOWN) == 0) {
                runState |= SHUTDOWN;
                terminate(false);
            }
        } finally {
            lock.unlock();
        }
    }

    /**
     * Sets STOP, cancels all queued tasks and terminates.
     *
     * @return always an empty list (cancelled tasks are not returned)
     */
    @Override
    public List<Runnable> shutdownNow() {
        final ReentrantLock lock = this.submissionQueuesLock;
        lock.lock();
        try {
            if ((runState & (STOP | TERMINATED)) == 0) {
                runState |= STOP;
                // workQueues is null until the first queue table is installed.
                WorkQueue[] ws = workQueues;
                if (ws != null) {
                    for (WorkQueue w : ws) {
                        if (w != null)
                            w.cancelAll();
                    }
                }
                terminate(false);
            }
        } finally {
            lock.unlock();
        }
        return Collections.emptyList();
    }

    /**
     * Returns true once shutdown() or shutdownNow() has been called.
     */
    @Override
    public boolean isShutdown() {
        return (runState & (SHUTDOWN | STOP | TERMINATED)) != 0;
    }

    /**
     * Returns true once the TERMINATED flag is set.
     */
    @Override
    public boolean isTerminated() {
        return (runState & TERMINATED) != 0;
    }

    /**
     * Blocks until TERMINATED is set or the timeout elapses.
     *
     * @return true if terminated, false on timeout
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit)
        throws InterruptedException {
        long nanos = unit.toNanos(timeout);
        final ReentrantLock lock = this.submissionQueuesLock;
        lock.lock();
        try {
            for (;;) {
                if ((runState & TERMINATED) != 0)
                    return true;
                if (nanos <= 0)
                    return false;
                nanos = termination.awaitNanos(nanos);
            }
        } finally {
            lock.unlock();
        }
    }

    /**
     * Sets STOP or SHUTDOWN, cancels all queued tasks, sets TERMINATED and
     * wakes termination waiters. Callers hold submissionQueuesLock.
     *
     * NOTE(review): no zero-argument signalWork() is declared in this listing.
     */
    private void terminate(boolean now) {
        runState |= (now ? STOP : SHUTDOWN);
        signalWork();
        // workQueues is null until the first queue table is installed.
        WorkQueue[] ws = workQueues;
        if (ws != null) {
            for (WorkQueue w : ws) {
                if (w != null)
                    w.cancelAll();
            }
        }
        runState |= TERMINATED;
        termination.signalAll();
    }
}
2. ForkJoinWorkerThread工作线程类
/*
 * ForkJoinWorkerThread is the worker thread type used by ForkJoinPool.
 * Each worker owns exactly one work queue.
 */
public class ForkJoinWorkerThread extends Thread {
    // Pool this worker belongs to.
    final ForkJoinPool pool;
    // Work queue owned by this worker.
    final ForkJoinPool.WorkQueue workQueue;
    // Set on the first call to run(); ensures the worker loop runs at most once.
    private boolean hasRun;

    /**
     * Creates a worker for the given pool. The thread name is the pool's
     * worker-name prefix followed by a process-wide sequence number.
     *
     * @param pool the pool this worker belongs to
     */
    protected ForkJoinWorkerThread(ForkJoinPool pool) {
        super(pool.workerNamePrefix + getAndIncrementPoolThreadNumber());
        this.pool = pool;
        this.workQueue = new ForkJoinPool.WorkQueue(this);
        // Pool workers are daemon threads, so idle workers (notably those of the
        // common pool) never keep the JVM alive.
        setDaemon(true);
    }

    /**
     * Returns the pool this worker belongs to.
     *
     * @return the owning ForkJoinPool
     */
    public ForkJoinPool getPool() {
        return pool;
    }

    /**
     * Returns the work queue owned by this worker.
     *
     * @return the worker's queue
     */
    public ForkJoinPool.WorkQueue getWorkQueue() {
        return workQueue;
    }

    /**
     * Worker body. It calls onStart(), runs the pool's worker loop, then calls
     * onTermination() and always deregisters from the pool. The first exception
     * raised (by the loop or by onTermination) is passed on to deregistration.
     */
    @Override
    public void run() {
        // The WorkQueue constructor allocates the queue's array, so the array
        // cannot serve as a "first run" marker; use an explicit flag instead.
        if (!hasRun) {
            hasRun = true;
            Throwable exception = null;
            try {
                onStart();
                pool.runWorker(workQueue);
            } catch (Throwable ex) {
                exception = ex;
            } finally {
                try {
                    onTermination(exception);
                } catch (Throwable ex) {
                    if (exception == null)
                        exception = ex;
                } finally {
                    pool.deregisterWorker(this, exception);
                }
            }
        }
    }

    /**
     * Hook invoked at the start of run(), before any tasks are processed.
     * Subclasses may override; the default does nothing.
     */
    protected void onStart() {
    }

    /**
     * Hook invoked when the worker loop ends. Subclasses may override; the
     * default does nothing.
     *
     * @param exception the exception that ended the loop, or null on normal termination
     */
    protected void onTermination(Throwable exception) {
    }

    /**
     * Returns the next worker sequence number (shared by all pools), starting at 0.
     */
    private static synchronized int getAndIncrementPoolThreadNumber() {
        return poolThreadNumber++;
    }

    // Guarded by the class monitor via getAndIncrementPoolThreadNumber().
    private static int poolThreadNumber;
}
3. 使用示例
/**
* 递归任务示例 - 计算数组元素之和
*/
public class SumTask extends RecursiveTask<Integer> {
private static final int THRESHOLD = 1000;
private final int[] array;
private final int start;
private final int end;
public SumTask(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected Integer compute() {
// 如果任务足够小,直接计算
if (end - start <= THRESHOLD) {
int sum = 0;
for (int i = start; i < end; i++) {
sum += array[i];
}
return sum;
} else {
// 否则分解任务
int mid = (start + end) / 2;
SumTask leftTask = new SumTask(array, start, mid);
SumTask rightTask = new SumTask(array, mid, end);
// 异步执行左任务
leftTask.fork();
// 同步执行右任务并获取结果
int rightResult = rightTask.compute();
// 获取左任务结果
int leftResult = leftTask.join();
return leftResult + rightResult;
}
}
}
/**
* 递归操作示例 - 并行处理数组
*/
public class ProcessArrayAction extends RecursiveAction {
private static final int THRESHOLD = 1000;
private final int[] array;
private final int start;
private final int end;
public ProcessArrayAction(int[] array, int start, int end) {
this.array = array;
this.start = start;
this.end = end;
}
@Override
protected void compute() {
if (end - start <= THRESHOLD) {
// 直接处理小数组
for (int i = start; i < end; i++) {
array[i] = processElement(array[i]);
}
} else {
// 分解大数组
int mid = (start + end) / 2;
ProcessArrayAction leftAction = new ProcessArrayAction(array, start, mid);
ProcessArrayAction rightAction = new ProcessArrayAction(array, mid, end);
// 并行执行两个子任务
invokeAll(leftAction, rightAction);
}
}
private int processElement(int element) {
// 处理单个元素的逻辑
return element * 2;
}
}
/**
 * Demonstrates SumTask and ProcessArrayAction on a dedicated ForkJoinPool.
 */
public class ForkJoinExample {
    public static void main(String[] args) {
        // Create a dedicated pool (parallelism = number of available processors).
        ForkJoinPool pool = new ForkJoinPool();
        try {
            // SumTask returns an Integer, so the array is sized to keep
            // 0 + 1 + ... + (n - 1) within int range. For n = 50000 the total is
            // 1249975000, whereas a million elements would overflow.
            int[] array = new int[50000];
            for (int i = 0; i < array.length; i++) {
                array[i] = i;
            }

            // Sum the array with a recursive task.
            SumTask sumTask = new SumTask(array, 0, array.length);
            int sum = pool.invoke(sumTask);
            System.out.println("Sum: " + sum);

            // Transform the array in place with a recursive action.
            ProcessArrayAction processAction = new ProcessArrayAction(array, 0, array.length);
            pool.invoke(processAction);
        } finally {
            // Initiate shutdown even if a task throws.
            pool.shutdown();
        }
    }
}
4. 核心设计要点总结
4.1 工作窃取算法(Work-Stealing)
- 每个工作线程都有自己的双端队列(deque)
- 工作线程从自己队列的顶部获取任务(LIFO)
- 空闲线程从其他线程队列的底部窃取任务(FIFO)
- 减少竞争,提高并行效率
4.2 自适应调整
- 根据工作负载动态调整线程数量
- 线程空闲时会阻塞,有任务时自动唤醒
- 支持任务的负载均衡
4.3 高性能设计
- 使用无锁算法和CAS操作
- 减少线程上下文切换
- 优化内存访问模式
4.4 任务分解模式
- 支持分治算法
- 递归任务分解
- 自动的任务调度和合并
4.5 线程池管理
- 支持优雅关闭
- 异常处理机制
- 线程生命周期管理
ForkJoinPool通过这些设计,为并行计算提供了高效、灵活的解决方案,特别适合处理可以分解为多个子任务的计算密集型任务。

377

被折叠的 条评论
为什么被折叠?



