package com.test.thread.forkjoin;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
/**
* Created by jl on 2018/8/31 0031
* fork join框架采用分治法将任务分解从而提高整体任务执行效率
* forkJoinTask.fork()方法用于将任务继续拆分或执行,如果是拆分,当前线程就充当了监工无法被分配到任务
* forkJoinTask.invokeAll(task1,task2,...)方法是针对fork方法的优化,被invoke的n个任务中会将第一个任务留给当前线程
* 去执行,这样递归循环下去从而保证所有的线程都充当工人角色(有些线程既是工人也是监工,没有只是监工的线程)
* forkJoinTask.join()方法用于返回任务执行的结果
*/
public class Demo {
public static void main(String[] args) throws ExecutionException, InterruptedException {
System.out.println(Thread.currentThread().getName());
ForkJoinPool forkJoinPool = new ForkJoinPool(5);
CountTask task = new CountTask(1, 6);
Future<Integer> result = forkJoinPool.submit(task);
System.out.println(result.get());
}
}
package com.test.thread.forkjoin;
import java.util.concurrent.RecursiveTask;
/**
* Created by jl on 2018/8/31 0031
* 用于计算连续自然数相加的任务
*/
public class CountTask extends RecursiveTask<Integer> {
//阈值,任务大小低于该值则结束拆分
private static final int THRESHOLD = 2;
private int start;
private int end;
public CountTask(int start, int end) {
this.start = start;
this.end = end;
}
protected Integer compute() {
int sum = 0;
boolean canCompute = (end - start) <= THRESHOLD;
if (canCompute) {
System.out.println(Thread.currentThread().getName() + "(工人)");
// 模拟工作耗时
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
//进行运算
for (int i = start; i <= end; i++) {
sum += i;
}
} else {
System.out.println(Thread.currentThread().getName() + "(监工)");
//如果大于阈值,就进行任务拆分
int middle = (start + end) / 2;
CountTask leftTask = new CountTask(start, middle);
CountTask rightTask = new CountTask(middle + 1, end);
//执行子任务
// leftTask.fork();
// rightTask.fork();
// invokeAll的N个任务中,其中N-1个任务会使用fork()交给其它线程执行,
// 但是,它还会留一个任务自己执行,这样,就充分利用了线程池,保证没有空闲的线程。
invokeAll(leftTask, rightTask);
//等待子任务执行完成,得到执行结果
int leftResult = leftTask.join();
int rightResult = rightTask.join();
//合并子任务结果
sum = leftResult + rightResult;
}
return sum;
}
}