前提
想想一下,如果一项任务要处理的数据比较多,或者耗时较长,我们会怎么做呢?很容易回想到 ——“多线程”。
那么再想一想,使用多线程又会存在两个问题。1. 多线程在主线程中如何知道多个线程何时执行结束呢?2. 多线程中的变量如何保证其安全性。
这里我们主要讨论第一个问题,那就是我们如何得知多个线程何时执行结束,对于这个问题我们设想一个应用场景,比如说有一个流程需要一个步骤一个步骤来,但是某一步耗时较长,虽然这一个步骤执行的时间较长,但是必须要等到这一步执行结束之后才能进行下一步的操作。如果我们只是简单的 使用Runnable 或者是使用线城池将Runnable对象丢进去执行,那么由于多线程的关系,是无法保证其执行顺序的。
所幸,java给我们提供了一个多线程的工具类 —— “CountDownLatch”, 这个类的作用就相当于一个计数器,当这个计时器没有执行结束的时候会使得线程等待其他线程执行,只有等到计时器执行结束才让线程继续执行。
CountDownLatch
其中CountDownLatch提供了两个主要的方法
- ‘await()’
在要进行等待的线程中执行,执行await的线程会被挂起,直到 CountDownLatch 这个计数器计数结束为止 - ‘countdown’ 在被等待的线程中使用,没被调用一次,计数器就减一,直到计数器的值被减到0的时候,调用await的线程才被执行
从CountDownLatch提供的两个方法可以看出,使用它我们可以很方便的控制线程执行的顺序,比如说制造商品,买原材料一个人足够,但是制作商品需要好几个人来完成,最后的商品运输出去也只需要一个人来执行,但是商品运输需要等待商品制作完成的,不可能商品还没制作结束就进行运输。因此通过对CountDownLatch的使用我们希望的是能把商品制作这件事给分割开来,而又不希望在商品制作完成前就执行商品运输的事情。
因此,下面对CountDownLatch进行简单的封装,我们希望封装好的CountDownLatch能做到下面的两件事
- 使用多线程加快耗时较长的任务的执行速度
- 使用多线程而不影响整体流程的进行,不会出现流程顺序不正确的问题。
下面来看具体的代码实现:
1. 分割任务
将任务进行分割要怎么分割,这里提供一个最简单的方式,就是按照任务量进行分割,这里使用List作为任务要处理的数据,分割也就是对将一个List分割为几个同等大小的List。代码如下:
/**
* 分割list为多个子list
* @param list 要进行分割的list
* @param childNum 每个list的中item的数量大小
* @return 分割后的所有list组合
*/
public static <E> List<List<E>> split(List<E> list, int childNum) {
if (childNum <= 0) {
throw new IllegalArgumentException("Param childNum can't less than zero!");
}
if (isNotEmpty(list)) {
int initialCapacity = (list.size() % childNum == 0) ? list.size() / childNum : list.size() / childNum + 1;
List<List<E>> result = new ArrayList<List<E>>();
for (int i = 0; i < initialCapacity; i++) {
result.add(list.subList(i * childNum, Math.min((i + 1) * childNum, list.size())));
}
return result;
}
return Collections.emptyList();
}
上面的代码很简答,就是将主任务的List按照每个子任务的任务数量进行分割。
2. 使用多线程
接下来就需要使用CountDownLatch来进行对多线程进行控制了,怎么进行使用呢?
- 首先我们要将我们要处理 的数据根据第一步的步骤分割成等大小的子任务。
- 其次我们使用CountDownLatch的await方法让要等待的线程进行等待。
- 最后再子任务执行结束之后使用CountDownLatch 的 countDown方法减少计数
这样我们就可以在不影响正确流程的情况下使用多线程处理数据了,代码如下:
/**
* 将任务拆分为子任务使用线程进行执行,加快任务分析速度
*
* @param childTaskReses 子任务要处理的数据列表
* @param childTaskLogic 每个子任务的处理逻辑
* @param errorHandler 子任务执行异常的回调方法
* @param progressHandler 执行进度回调方法
*/
public <E> void splitTaskExec(List<List<E>> childTaskReses, ObjIntConsumer<List<E>> childTaskLogic,ObjIntConsumer<Exception> errorHandler, DoubleConsumer progressHandler) {
int taskCount = childTaskReses.size();
// 控制多个子线程开始的计数器
CountDownLatch begin = new CountDownLatch(1);
// 控制多个子线程结束的计数器,计数器初始数为子任务的数量
CountDownLatch end = new CountDownLatch(taskCount);
for (int i = 0; i < childTaskReses.size(); i++) {
// 如果存在线程嵌套问题,那么把线程丢进Cache线程池中,防止阻塞发生
List<E> childTaskRe = childTaskReses.get(i);
// 记录子任务的位置
int location = i;
Runnable runnable = () -> {
try {
begin.await();
childTaskLogic.accept(childTaskRe, location);
} catch (InterruptedException e) {
System.err.println("子任务执行异常:" + Thread.currentThread().getName() + "->" + e.getMessage());
if (errorHandler != null) {
errorHandler.accept(e, location);
}
} finally {
// 自任务执行结束后减少计数
end.countDown();
if (progressHandler != null) {
progressHandler.accept((taskCount - end.getCount()) * 1.0 / taskCount);
}
}
};
// 使用线程池执行
getExecutor().submit(runnable);
}
try {
begin.countDown();
// 调用分割子任务的线程等待计数器执行结束(等待所有子任务执行结束)
end.await();
} catch (InterruptedException e) {
System.err.println("主任务执行异常:" + e.getMessage());
}
}
在上面的逻辑比较简单,第一个参数就是使用我们第一步的分割方法分割的子任务列表,然后我们创建一个和子任务数量相当的计数器(CountDownLatch),我们为每一个子任务生成一个Runnable对象,用于执行子任务的逻辑,其中子任务的任务逻辑通过第二个参数 childTaskLogic 交给用户去进行实现,由于子任务要怎么进行处理我们并不清楚,所以使用回调交给用户自己实现是最好的方法。同时这里还添加了后面的两个参数errorHandler和progressHandler用来让用户处理子任务处理失败之后的逻辑以及获取任务处理进度。
细心地人估计能看到上面有两个CountDownLatch,那个end比较好理解,就是我们上面说的等待计数的作用,那么那个begin呢?可以看到它的计数只有1,它的作用是要等待所有线程都初始化了,添加进了线程池之后统一进行执行。当然去掉它将所有的子线程放在一个列表里面统一加入到线程池也可以,这里只是一种比较简单的实现方式。
3. 优化
为了使用方便,我们可以提供一些重载的函数,例如去掉错误处理逻辑和执行进度逻辑,又例如将第一步的流程和第二步的流程合并到一起,通过执行主任务和子任务数量执行。代码如下:
/**
* 将任务拆分为子任务使用线程进行执行,加快任务分析速度
*
* @param res 要拆分的总任务
* @param childTaskItemNum 每个子任务中的item数量
* @param childTaskLogic 每个子任务的处理逻辑
*/
public <E> void splitTaskExec(List<E> res, int childTaskItemNum, ObjIntConsumer<List<E>> childTaskLogic) {
List<List<E>> childTaskRes = CollectionUtil.split(res, childTaskItemNum);
splitTaskExec(childTaskRes, childTaskLogic, null, null);
}
/**
* 将任务拆分为子任务使用线程进行执行,加快任务分析速度
*
* @param res 要拆分的总任务
* @param childTaskItemNum 每个子任务中的item数量
* @param childTaskLogic 每个子任务的处理逻辑
*/
public <E> void splitTaskExec(List<E> res, int childTaskItemNum, ObjIntConsumer<List<E>> childTaskLogic,
ObjIntConsumer<Exception> exceptionHandler, DoubleConsumer processHandler) {
List<List<E>> childTaskRes = CollectionUtil.split(res, childTaskItemNum);
splitTaskExec(childTaskRes, childTaskLogic, exceptionHandler, processHandler);
}
这样我们就可以使用了,而且使用起来也比较方便,甚至我们还可以嵌套进行使用,例如:
// 非嵌套使用,每个子任务有50个元素,这里被分成了 200/50=4个子任务
splitTaskExec(TestUtil.getArangeList(200), 50, (ls,index) -> {
//模拟子任务耗时操作
long start = System.currentTimeMillis();
System.out.println("j=" + j);
while (System.currentTimeMillis() - start < 1000) {
}
});
// 嵌套使用
splitTaskExec(TestUtil.getArangeList(50), 1, (lss, index) -> {// 这里分了50个子任务
List<String> res = TestUtil.getArangeList(200);
// 这里在每个之人物里面又分了100个子任务
splitTaskExec(res, 2, (ls,index2) -> {
long start = System.currentTimeMillis();
while (System.currentTimeMillis() - start < 1000) {
}
});
});
4. 线程池
上面提到了,我们使用了线程池对线程进行管理,那么要使用那种线程池呢?有两种选择一个是固定大小的,例如Executors.newFixedThreadPool().一种是没有固定大小的,例如 Executors.newCachedThreadPool(), 如果使用第一种,在进行嵌套使用的时候线程数过多可能会阻塞卡死,因为外面的主任务在等待子任务执行完毕,而外面的主任务已经将线城池沾满了无法执行子任务就会导致相互等待的情况出现,如果使用CachedThreadPool的话就不会出现这个问题,但是个人感觉线程数太多也不是什么好事,因此我的方法是这两种线程都使用,当然如果只是用CachedThreadPool也不会有什么问题,另外如果不嵌套使用也不会有问题,我的逻辑是这样的。
判断当前子任务的线程的父线程是什么,如果是嵌套线程,如果父线程是Fix的添加进Cache中,如果是Cache的添加进Fix中.使得嵌套的线程之间相互不影响。那么怎么判断是那种线程呢?这里就需要在创建线程池的时候添加ThreadFactory参数了,这个参数可以让我们自定义线程,这样我们就可以给我们自己的线程定制名字,就可以进行按断了。这里我仿照Java源码中的写了自己的实现。
private static class TUThreadFactory implements ThreadFactory {
private final ThreadGroup group;
private final AtomicInteger threadNumber = new AtomicInteger(1);
private final static String namePrefix = "TUTheadPool";
private final String _namePrefix;
public TUThreadFactory(String poolName) {
SecurityManager s = System.getSecurityManager();
group = (s != null) ? s.getThreadGroup() : Thread.currentThread().getThreadGroup();
_namePrefix = namePrefix + "-" + poolName + "-thread-";
}
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(group, r, _namePrefix + threadNumber.getAndIncrement(), 0);
if (t.isDaemon()) {
t.setDaemon(false);
}
if (t.getPriority() != Thread.NORM_PRIORITY) {
t.setPriority(Thread.NORM_PRIORITY);
}
return t;
}
}
上面的逻辑就相当于给每一个添加进来的任务都制定了一个自定义的名字,这样就方便我么知道当前是在哪一种线程池中了。
创建线程池的代码如下:
int processorsCount = Runtime.getRuntime().availableProcessors() + 1;
this.fixRoundExecutor =(ThreadPoolExecutor)Executors.newFixedThreadPool(processorsCount*2,new TUThreadFactory("fixed"));
this.cacheExecutor = Executors.newCachedThreadPool(new TUThreadFactory("cache"));
上面生成了两种线程池,一个是固定大小的,一种是无限大小的。
怎么确定使用哪个呢?由于使用了ThreadFactory,只需判断线程前缀即可,代码如下:只要固定大小的线程池没有满就优先添加进固定的线程池中
public ExecutorService getExecutor() {
if (Thread.currentThread().getName().startsWith(TUThreadFactory.namePrefix)
&& fixRoundExecutor.getTaskCount() > fixRoundExecutor.getMaximumPoolSize()) {
return cacheExecutor;
} else {
return fixRoundExecutor;
}
}
至此,所有的代码书写完毕,这里附上完整的代码,也可以去github下载:
查看源码
完整代码
package com.mengfly.lib.concurrent;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.DoubleConsumer;
import java.util.function.ObjIntConsumer;
import com.mengfly.lib.CollectionUtil;
public class TaskUtil {
private ThreadPoolExecutor fixRoundExecutor;
private ExecutorService cacheExecutor;
private static volatile TaskUtil instance;
private TaskUtil() {
// 创建线程池
int processorsCount = Runtime.getRuntime().availableProcessors() + 1;
// this.fixRoundExecutor = new ThreadPoolExecutor(processorsCount, processorsCount * 2, 60L, TimeUnit.SECONDS,
// new LinkedBlockingDeque<>(), new TUThreadFactory("fixed"));
this.fixRoundExecutor = (ThreadPoolExecutor) Executors.newFixedThreadPool(processorsCount,
new TUThreadFactory("fixed"));
this.cacheExecutor = Executors.newCachedThreadPool(new TUThreadFactory("cache"));
}
/**
* 将任务拆分为子任务使用线程进行执行,加快任务分析速度
*
* @param childTaskReses 子任务要处理的数据列表
* @param childTaskLogic 每个子任务的处理逻辑
*/
public <E> void splitTaskExec(List<List<E>> childTaskReses, ObjIntConsumer<List<E>> childTaskLogic,
ObjIntConsumer<Exception> errorHandler, DoubleConsumer progressHandler) {
int taskCount = childTaskReses.size();
CountDownLatch begin = new CountDownLatch(1);
CountDownLatch end = new CountDownLatch(taskCount);
for (int i = 0; i < childTaskReses.size(); i++) {
// 如果存在线程嵌套问题,那么把线程丢进Cache线程池中,防止阻塞发生
List<E> childTaskRe = childTaskReses.get(i);
int location = i;
Runnable runnable = () -> {
try {
begin.await();
childTaskLogic.accept(childTaskRe, location);
} catch (InterruptedException e) {
System.err.println("子任务执行异常:" + Thread.currentThread().getName() + "->" + e.getMessage());
if (errorHandler != null) {
errorHandler.accept(e, location);
}
} finally {
end.countDown();
if (progressHandler != null) {
progressHandler.accept((taskCount - end.getCount()) * 1.0 / taskCount);
}
}
};
getExecutor().submit(runnable);
}
try {
begin.countDown();
end.await();
} catch (InterruptedException e) {
System.err.println("主任务执行异常:" + e.getMessage());
}
}
/**
* 将任务拆分为子任务使用线程进行执行,加快任务分析速度
*
* @param res 要拆分的总任务
* @param childTaskItemNum 每个子任务中的item数量
* @param childTaskLogic 每个子任务的处理逻辑
*/
public <E> void splitTaskExec(List<E> res, int childTaskItemNum, ObjIntConsumer<List<E>> childTaskLogic) {
List<List<E>> childTaskRes = CollectionUtil.split(res, childTaskItemNum);
splitTaskExec(childTaskRes, childTaskLogic, null, null);
}
/**
* 将任务拆分为子任务使用线程进行执行,加快任务分析速度
*
* @param res 要拆分的总任务
* @param childTaskItemNum 每个子任务中的item数量
* @param childTaskLogic 每个子任务的处理逻辑
*/
public <E> void splitTaskExec(List<E> res, int childTaskItemNum, ObjIntConsumer<List<E>> childTaskLogic,
ObjIntConsumer<Exception> exceptionHandler, DoubleConsumer processHandler) {
List<List<E>> childTaskRes = CollectionUtil.split(res, childTaskItemNum);
splitTaskExec(childTaskRes, childTaskLogic, exceptionHandler, processHandler);
}
public static TaskUtil getInstance() {
if (instance == null) {
synchronized (TaskUtil.class) {
if (instance == null) {
instance = new TaskUtil();
}
}
}
return instance;
}
public ExecutorService getExecutor() {
if (Thread.currentThread().getName().startsWith(TUThreadFactory.namePrefix)
&& fixRoundExecutor.getTaskCount() > fixRoundExecutor.getMaximumPoolSize()) {
return cacheExecutor;
} else {
return fixRoundExecutor;
}
}
private static class TUThreadFactory implements ThreadFactory {
private final ThreadGroup group;
private final AtomicInteger threadNumber = new AtomicInteger(1);
private final static String namePrefix = "TUTheadPool";
private final String _namePrefix;
public TUThreadFactory(String poolName) {
SecurityManager s = System.getSecurityManager();
group = (s != null) ? s.getThreadGroup() : Thread.currentThread().getThreadGroup();
_namePrefix = namePrefix + "-" + poolName + "-thread-";
}
@Override
public Thread newThread(Runnable r) {
Thread t = new Thread(group, r, _namePrefix + threadNumber.getAndIncrement(), 0);
if (t.isDaemon()) {
t.setDaemon(false);
}
if (t.getPriority() != Thread.NORM_PRIORITY) {
t.setPriority(Thread.NORM_PRIORITY);
}
return t;
}
}
}