线程池判断当前任务组是否都已经结束

最近写测试代码的时候,需要判断提交的线程池里面的任务是否都已经全部执行完成。在网上找到了一圈,发现
https://blog.youkuaiyun.com/flycp/article/details/106337294 这篇博客写到算是比较完整的。但是还是有几个问题。

目前常见方案的缺陷

CountDownLatch

1 对业务代码有侵入
2 需要提前知道任务的总数

Future.get()

Future.get()的方案没有问题,但是依然是由用户来自己判断任务是否完成。
我认为这个功能应该是线程池本身来提供的。

我的思路

package com.alibaba.common;


import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.*;

public class CustomThreadPool {
    private static final Logger logger = LoggerFactory.getLogger(CustomThreadPool.class);

    private final ScheduledExecutorService scheduler;
    private final Map<String, CopyOnWriteArrayList<Future<?>>> taskGroups = new ConcurrentHashMap<>();

    private final ThreadPoolExecutor executor;

    private static final String DEFAULT_GROUP_ID = "DEFAULT_GROUP_ID";


    public CustomThreadPool(int corePoolSize, int maximumPoolSize) {
        this.scheduler = Executors.newSingleThreadScheduledExecutor();

        // 安排每1秒执行一次cleanupCompletedTaskGroups
        this.scheduler.scheduleAtFixedRate(this::cleanupCompletedTaskGroups, 5, 10, TimeUnit.SECONDS);

        executor = new ThreadPoolExecutor(corePoolSize, maximumPoolSize, 3, TimeUnit.SECONDS
                , new LinkedBlockingQueue<>(1024), Executors.defaultThreadFactory(), new ThreadPoolExecutor.AbortPolicy());
        executor.allowCoreThreadTimeOut(true);
    }


    public void submitTask(Runnable task) {
        submitTask(DEFAULT_GROUP_ID, task);
    }

    /**
     * 提交单个 Runnable 任务到线程池,并将其与任务组ID关联。
     */
    public void submitTask(String groupId, Runnable task) {
        List<Future<?>> futures = taskGroups.computeIfAbsent(groupId, k -> new CopyOnWriteArrayList<>());
        futures.add(executor.submit(task));
    }

    public <T> Future<T> submitTask(Callable<T> task) {
        return submitTask(DEFAULT_GROUP_ID, task);
    }


    /**
     * 提交单个 Callable 任务到线程池,并将其与任务组ID关联。
     */
    public <T> Future<T> submitTask(String groupId, Callable<T> task) {
        List<Future<?>> futures = taskGroups.computeIfAbsent(groupId, k -> new CopyOnWriteArrayList<>());
        Future<T> future = executor.submit(task);
        futures.add(future);
        return future;
    }


    /**
     * 等待直到当前实例中的所有任务组的所有任务都已完成。
     */
    public void awaitCompletion(int timeOutSecond) throws InterruptedException, ExecutionException, TimeoutException {
        logger.info("awaitCompletion all task finish");
        for (List<Future<?>> futures : taskGroups.values()) {
            for (Future<?> future : futures) {
                try {
                    // get()会阻塞,直到对应的异步计算完成或超时
                    future.get(timeOutSecond, TimeUnit.SECONDS);
                } catch (TimeoutException e) {
                    logger.warn("Task timed out after {} seconds", timeOutSecond);
                    throw e;
                }
            }
        }
        logger.info("all task finish");
    }

    /**
     * 关闭线程池。
     */
    public void shutdown() {
        logger.info("send shutdown command");
        executor.shutdown();
        scheduler.shutdown();
        try {
            if (!executor.awaitTermination(8000, TimeUnit.MILLISECONDS)) {
                logger.info(" executor thread pool shutdownNow");
                executor.shutdownNow();
            }
            if (!scheduler.awaitTermination(8000, TimeUnit.MILLISECONDS)) {
                logger.info(" scheduler thread pool shutdownNow");
                scheduler.shutdownNow();
            }

            logger.info("thread pool shutdown succ");
        } catch (InterruptedException e) {
            executor.shutdownNow();
            scheduler.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    /**
     * 移除已经完成的任务组,以便释放资源。
     */
    private void cleanupCompletedTaskGroups() {
        for (Map.Entry<String, CopyOnWriteArrayList<Future<?>>> entry : taskGroups.entrySet()) {
            String groupId = entry.getKey();
            List<Future<?>> futures = entry.getValue();

            // 使用迭代器遍历并移除已完成的任务
            Iterator<Future<?>> iterator = futures.iterator();
            while (iterator.hasNext()) {
                Future<?> future = iterator.next();
                if (future.isDone()) {
                    try {
                        future.get(); // 获取结果,如果有的话
                    } catch (InterruptedException | ExecutionException ex) {
                        logger.warn("Error occurred while getting the result of a completed task in group {}: {}", groupId, ex.getMessage());
                        Thread.currentThread().interrupt(); // Restore interrupted status
                    }
                    iterator.remove(); // 移除已完成的任务
                }
            }

            // 如果任务组中没有剩余任务,则移除该任务组
            if (futures.isEmpty()) {
                logger.info("All tasks in {}, have been finished", groupId);
                taskGroups.remove(groupId);
            }
        }
    }

    // 示例用法
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        CustomThreadPool customThreadPool = new CustomThreadPool(2, 3); // 创建一个包含4个线程的线程池

        // 提交单个任务并指定任务组ID,使用lambda表达式
        customThreadPool.submitTask("group1", () -> {
                    logger.info("Task 1-1 is running");
                    try {
                        Thread.sleep(new Random().nextInt(8000));
                    } catch (InterruptedException e) {
                        throw new RuntimeException(e);
                    }
                    logger.info("Task 1-1 is finished");
                }
        );

        customThreadPool.submitTask("group1", () -> {
                    logger.info("Task 1-2 is running");
                    try {
                        Thread.sleep(new Random().nextInt(8000));
                    } catch (InterruptedException e) {
                        throw new RuntimeException(e);
                    }
                    logger.info("Task 1-2 is finished");

                }
        );

        // 提交带有返回值的任务
        Future<Integer> future = customThreadPool.submitTask("group2", () -> {
            logger.info("Task 2-1 with result is running");
            Thread.sleep(new Random().nextInt(8000));
            logger.info("Task 2-1 is finished");
            return 42; // 返回一个结果
        });

        // 提交带有返回值的任务
        Future<Integer> future2 = customThreadPool.submitTask("group2", () -> {
            logger.info("Task 2-2 with result is running");
            Thread.sleep(new Random().nextInt(8000));
            logger.info("Task 2-2 is finished");
            return 43; // 返回一个结果
        });

        // 提交带有返回值的任务
        Future<Integer> future3 = customThreadPool.submitTask("group3", () -> {
            logger.info("Task 3-1 with result is running");
            Thread.sleep(new Random().nextInt(8000));
            logger.info("Task 3-1 is finished");
            return 45; // 返回一个结果
        });

        // 等待直到所有任务组的任务完成
        try {
            customThreadPool.awaitCompletion(10);
        } catch (TimeoutException e) {
            customThreadPool.shutdown();
        }

        // 关闭线程池
        customThreadPool.shutdown();

        logger.info("All tasks are completed.");
    }
}

参考资料

https://blog.youkuaiyun.com/flycp/article/details/106337294

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值