最近写测试代码的时候,需要判断提交到线程池里的任务是否已经全部执行完成。在网上找了一圈,发现
https://blog.csdn.net/flycp/article/details/106337294 这篇博客写得算是比较完整的,但仍然存在几个问题。
目前常见方案的缺陷
CountDownLatch
1 对业务代码有侵入
2 需要提前知道任务的总数
Future.get()
Future.get()的方案本身没有问题,但仍然需要使用者自己保存每个Future并逐个判断任务是否完成。
我认为这个功能应该由线程池本身来提供。
我的思路:在线程池内部按任务组记录每个提交任务的Future,由线程池自己提供awaitCompletion()方法统一等待所有任务完成,使用者既不需要提前知道任务总数,也不需要修改业务代码。
package com.alibaba.common;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.*;
/**
 * Thread pool wrapper that remembers the {@link Future} of every submitted task.
 *
 * <p>Futures are grouped by a caller-chosen group id. Callers can then block until every
 * submitted task has finished via {@link #awaitCompletion(int)}, without counting tasks or
 * keeping futures themselves.
 *
 * <p>A single background thread periodically forgets futures that are already done and drops
 * groups that have become empty, so finished tasks do not accumulate in memory.
 */
public class CustomThreadPool {
    private static final Logger logger = LoggerFactory.getLogger(CustomThreadPool.class);

    /** Group used by the {@code submitTask} overloads that take no group id. */
    private static final String DEFAULT_GROUP_ID = "DEFAULT_GROUP_ID";

    /** Runs {@link #cleanupCompletedTaskGroups()} periodically. */
    private final ScheduledExecutorService scheduler;

    /**
     * Futures of submitted tasks that the cleanup has not dropped yet, keyed by group id.
     *
     * <p>Typed as {@link ConcurrentHashMap} on purpose: registration and cleanup rely on its
     * per-key atomic {@code compute} / {@code computeIfPresent}.
     */
    private final ConcurrentHashMap<String, CopyOnWriteArrayList<Future<?>>> taskGroups =
            new ConcurrentHashMap<>();

    /** Executes the submitted tasks. */
    private final ThreadPoolExecutor executor;

    /**
     * Creates the pool and starts the periodic cleanup.
     *
     * @param corePoolSize    number of core worker threads (they also time out after 3 s idle)
     * @param maximumPoolSize upper bound on worker threads; with a bounded queue, threads beyond
     *                        the core size are only created once the 1024-slot queue is full
     */
    public CustomThreadPool(int corePoolSize, int maximumPoolSize) {
        // Build the executor first, so the object is fully initialized before `this` is
        // handed to the scheduler below.
        executor = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, 3, TimeUnit.SECONDS,
                new LinkedBlockingQueue<>(1024), Executors.defaultThreadFactory(),
                new ThreadPoolExecutor.AbortPolicy());
        // Idle core threads die after the keep-alive, so an unused pool holds no threads.
        executor.allowCoreThreadTimeOut(true);

        // Daemon thread: the housekeeping scheduler alone must not keep the JVM alive when
        // the caller forgets to call shutdown().
        this.scheduler = Executors.newSingleThreadScheduledExecutor(runnable -> {
            Thread thread = new Thread(runnable, "CustomThreadPool-cleanup");
            thread.setDaemon(true);
            return thread;
        });
        // First cleanup after 5 seconds, then every 10 seconds.
        this.scheduler.scheduleAtFixedRate(this::cleanupCompletedTaskGroups, 5, 10, TimeUnit.SECONDS);
    }

    /**
     * Submits a {@link Runnable} to the default task group.
     *
     * @param task the task to run
     * @throws RejectedExecutionException if the pool is shut down or its work queue is full
     */
    public void submitTask(Runnable task) {
        submitTask(DEFAULT_GROUP_ID, task);
    }

    /**
     * Submits a single {@link Runnable} to the pool and associates it with a task group.
     *
     * @param groupId group the task belongs to
     * @param task    the task to run
     * @throws RejectedExecutionException if the pool is shut down or its work queue is full
     */
    public void submitTask(String groupId, Runnable task) {
        register(groupId, executor.submit(task));
    }

    /**
     * Submits a {@link Callable} to the default task group.
     *
     * @param task the task to run
     * @param <T>  result type of the task
     * @return the future of the submitted task
     * @throws RejectedExecutionException if the pool is shut down or its work queue is full
     */
    public <T> Future<T> submitTask(Callable<T> task) {
        return submitTask(DEFAULT_GROUP_ID, task);
    }

    /**
     * Submits a single {@link Callable} to the pool and associates it with a task group.
     *
     * @param groupId group the task belongs to
     * @param task    the task to run
     * @param <T>     result type of the task
     * @return the future of the submitted task
     * @throws RejectedExecutionException if the pool is shut down or its work queue is full
     */
    public <T> Future<T> submitTask(String groupId, Callable<T> task) {
        Future<T> future = executor.submit(task);
        register(groupId, future);
        return future;
    }

    /**
     * Records {@code future} under {@code groupId}.
     *
     * <p>Uses {@link ConcurrentHashMap#compute}, so the add is atomic with the cleanup's
     * "drop the group if empty" step. With {@code computeIfAbsent(..).add(..)}, a future could
     * be appended to a list the cleanup had just unlinked from the map, and
     * {@link #awaitCompletion(int)} would then never wait for it.
     */
    private void register(String groupId, Future<?> future) {
        taskGroups.compute(groupId, (key, futures) -> {
            CopyOnWriteArrayList<Future<?>> group = futures;
            if (group == null) {
                group = new CopyOnWriteArrayList<>();
            }
            group.add(future);
            return group;
        });
    }

    /**
     * Blocks until every task currently registered in any task group has finished.
     *
     * <p>{@code timeOutSecond} is an overall budget for the whole call, not a per-task limit.
     * Tasks submitted while this method is running may or may not be waited for. A failed task
     * that the periodic cleanup has already dropped is logged by the cleanup and not rethrown
     * here.
     *
     * @param timeOutSecond maximum total time to wait, in seconds
     * @throws InterruptedException  if the calling thread is interrupted while waiting
     * @throws ExecutionException    if a registered task threw (the first one encountered)
     * @throws TimeoutException      if the budget runs out before all tasks have finished
     * @throws CancellationException if a registered task was cancelled
     */
    public void awaitCompletion(int timeOutSecond) throws InterruptedException, ExecutionException, TimeoutException {
        logger.info("awaitCompletion all task finish");
        // One shared deadline. Previously each future got the full timeout, so N slow tasks
        // could block for N * timeOutSecond.
        long deadline = System.nanoTime() + TimeUnit.SECONDS.toNanos(timeOutSecond);
        for (List<Future<?>> futures : taskGroups.values()) {
            for (Future<?> future : futures) {
                long remainingNanos = deadline - System.nanoTime();
                try {
                    // Returns immediately for finished tasks; otherwise waits at most the
                    // remaining budget.
                    future.get(remainingNanos, TimeUnit.NANOSECONDS);
                } catch (TimeoutException e) {
                    logger.warn("Task timed out after {} seconds", timeOutSecond);
                    throw e;
                }
            }
        }
        logger.info("all task finish");
    }

    /**
     * Shuts down the worker pool and the cleanup scheduler.
     *
     * <p>Already-submitted tasks get up to 8 seconds to finish before the pool is forced down
     * with {@code shutdownNow()}. If the calling thread is interrupted while waiting, both pools
     * are forced down and the interrupt flag is restored.
     */
    public void shutdown() {
        logger.info("send shutdown command");
        executor.shutdown();
        scheduler.shutdown();
        try {
            if (!executor.awaitTermination(8000, TimeUnit.MILLISECONDS)) {
                logger.info(" executor thread pool shutdownNow");
                executor.shutdownNow();
            }
            if (!scheduler.awaitTermination(8000, TimeUnit.MILLISECONDS)) {
                logger.info(" scheduler thread pool shutdownNow");
                scheduler.shutdownNow();
            }
            logger.info("thread pool shutdown succ");
        } catch (InterruptedException e) {
            executor.shutdownNow();
            scheduler.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Drops finished futures and removes task groups that have become empty, to free memory.
     *
     * <p>Runs on the scheduler thread. It must never let an exception escape: with
     * {@code scheduleAtFixedRate}, a single thrown exception silently cancels every later run.
     */
    private void cleanupCompletedTaskGroups() {
        try {
            for (Map.Entry<String, CopyOnWriteArrayList<Future<?>>> entry : taskGroups.entrySet()) {
                String groupId = entry.getKey();
                CopyOnWriteArrayList<Future<?>> futures = entry.getValue();
                // CopyOnWriteArrayList iterators do not support remove(); removeIf does.
                futures.removeIf(future -> {
                    if (!future.isDone()) {
                        return false;
                    }
                    logIfFailed(groupId, future);
                    return true;
                });
                // Drop the group only if it is still empty. This is atomic with register(),
                // so a concurrently submitted task cannot be lost.
                CopyOnWriteArrayList<Future<?>> remaining = taskGroups.computeIfPresent(
                        groupId, (key, current) -> current.isEmpty() ? null : current);
                if (remaining == null) {
                    logger.info("All tasks in {}, have been finished", groupId);
                }
            }
        } catch (RuntimeException ex) {
            logger.error("Failed to clean up completed task groups", ex);
        }
    }

    /**
     * Logs the outcome of an already-finished future if it failed or was cancelled.
     *
     * <p>Never blocks, because the future is done.
     */
    private static void logIfFailed(String groupId, Future<?> future) {
        try {
            future.get();
        } catch (ExecutionException ex) {
            // Trailing Throwable argument makes SLF4J log the task's stack trace as well.
            logger.warn("Error occurred while getting the result of a completed task in group {}: {}",
                    groupId, ex.getMessage(), ex.getCause());
        } catch (CancellationException ex) {
            logger.info("Task in group {} was cancelled", groupId);
        } catch (InterruptedException ex) {
            // Only a real interruption restores the interrupt flag.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Example usage.
     *
     * <p>Submits tasks to three groups, waits for all of them, then shuts the pool down exactly
     * once.
     */
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        // Pool with 2 core threads and at most 3 threads.
        CustomThreadPool customThreadPool = new CustomThreadPool(2, 3);
        try {
            // Submit tasks with an explicit group id, using lambda expressions.
            customThreadPool.submitTask("group1", () -> {
                logger.info("Task 1-1 is running");
                try {
                    Thread.sleep(ThreadLocalRandom.current().nextInt(8000));
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
                logger.info("Task 1-1 is finished");
            });
            customThreadPool.submitTask("group1", () -> {
                logger.info("Task 1-2 is running");
                try {
                    Thread.sleep(ThreadLocalRandom.current().nextInt(8000));
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
                logger.info("Task 1-2 is finished");
            });
            // Tasks that return a result.
            Future<Integer> future = customThreadPool.submitTask("group2", () -> {
                logger.info("Task 2-1 with result is running");
                Thread.sleep(ThreadLocalRandom.current().nextInt(8000));
                logger.info("Task 2-1 is finished");
                return 42;
            });
            Future<Integer> future2 = customThreadPool.submitTask("group2", () -> {
                logger.info("Task 2-2 with result is running");
                Thread.sleep(ThreadLocalRandom.current().nextInt(8000));
                logger.info("Task 2-2 is finished");
                return 43;
            });
            Future<Integer> future3 = customThreadPool.submitTask("group3", () -> {
                logger.info("Task 3-1 with result is running");
                Thread.sleep(ThreadLocalRandom.current().nextInt(8000));
                logger.info("Task 3-1 is finished");
                return 45;
            });
            // Wait until the tasks of every group have finished (10 s budget in total).
            customThreadPool.awaitCompletion(10);
            logger.info("All tasks are completed.");
            // All futures are done here, so get() returns immediately.
            logger.info("Results: {}, {}, {}", future.get(), future2.get(), future3.get());
        } catch (TimeoutException e) {
            logger.warn("Not all tasks completed within the timeout", e);
        } finally {
            // Shut the pool down exactly once, whether or not waiting succeeded.
            customThreadPool.shutdown();
        }
    }
}