仿照jdk源码实现简化版的线程池

本文仿照Java的ThreadPoolExecutor,简化实现了一个线程池,仅保留核心exec()函数,便于理解线程池的工作原理。通过设定核心线程数、最大线程数、保持时间等参数,实现了任务的分配和线程的管理。

仿照java.util.concurrent包中的ThreadPoolExecutor实现,做了许多简化,只保留核心exec()函数,帮助理解源码,仅供参考。

ThreadPool

package com.jsj.jdk.threadpool;

import java.util.HashSet;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

/**
 * 仿照jdk源码实现ThreadPool
 *
 * @author jsj
 * @date 2019-01-08
 */
public class ThreadPool {
    private final int corePoolSize;
    private final int maxPoolSize;
    private final long keepAliveTime;
    private ThreadFactory threadFactory;
    /**
     * 需要执行的任务队列
     */
    private BlockingQueue<Runnable> taskQueue;

    /**
     * 记录线程池中的Worker对象,其中Worker对象是线程的封装,加锁以后才能执行添加或删除操作
     */
    private final HashSet<Worker> workers = new HashSet<>();
    private final AtomicInteger threadCount = new AtomicInteger(0);
    private final ReentrantLock mainLock = new ReentrantLock();

    public ThreadPool(int corePoolSize, int maxPoolSize, long keepAliveTime, TimeUnit unit, ThreadFactory threadFactory, BlockingQueue<Runnable> taskQueue) {
        this.corePoolSize = corePoolSize;
        this.maxPoolSize = maxPoolSize;
        this.keepAliveTime = unit.toNanos(keepAliveTime);
        this.threadFactory = threadFactory;
        this.taskQueue = taskQueue;
    }

    public void exec(Runnable command) {
        if (command == null) throw new NullPointerException();
        //若当前线程数小于核心池数量,直接创建线程执行
        if (threadCount.get() < corePoolSize) {
            if (addWorker(command, true)) {
                return;
            }
        }
        //当前线程数大于corePoolSize,则加入taskQueue,若加入失败则再次创建线程执行该task
        if (!taskQueue.offer(command)) {
            System.out.println("taskQueue加入失败.." + command.toString() + ", 准备再次新建线程执行task");
            if (!addWorker(command, false)) {
                reject(command);
            }
        }
    }

    /**
     * 添加新线程
     *
     * @param firstTask 是否执行任务
     * @param core      是否核心池
     * @return
     */
    private boolean addWorker(Runnable firstTask, boolean core) {
        //检查线程数约束并使count++;
        retry:
        for (; ; ) {
            int size = core ? corePoolSize : maxPoolSize;
            for (; ; ) {
                if (threadCount.get() >= size) {
                    return false;
                }
                int c = threadCount.get();
                if (threadCount.compareAndSet(c, c + 1)) {
                    break retry;
                }
            }
        }
        boolean workerStarted = false;
        boolean workerAdded = false;
        //新建线程并封装成Worker对象
        Worker w = new Worker(firstTask);
        final Thread t = w.thread;
        final ReentrantLock lock = this.mainLock;
        lock.lock();
        try {
            //添加该线程
            workers.add(w);
            // precheck that t is startable
            if (t.isAlive()) throw new IllegalThreadStateException();
            workerAdded = true;
        } finally {
            lock.unlock();
        }
        if (workerAdded) {
            //线程启动
            t.start();
            workerStarted = true;
        }
        return workerStarted;
    }

    /**
     * 线程获取任务并执行
     *
     * @param w
     */
    final void runWorker(Worker w) {
        Runnable task = w.firstTask;
        w.firstTask = null;
        while (task != null || (task = getTask()) != null) {
            try {
                task.run();
            } finally {
                task = null;
            }
        }
        //获取任务为null,则移除该线程
        processWorkerExit(w);
    }

    /**
     * 移除线程
     *
     * @param w
     */
    private void processWorkerExit(Worker w) {
        final ReentrantLock mainLock = this.mainLock;
        mainLock.lock();
        try {
            workers.remove(w);
        } finally {
            mainLock.unlock();
        }
    }

    /**
     * 获取任务
     *
     * @return
     */
    private Runnable getTask() {
        // Did the last poll() time out?
        boolean timedOut = false;
        //死循环
        for (; ; ) {
            int wc = threadCount.get();

            // 线程获取任务是否需要计时
            boolean timed = wc > corePoolSize;
            // 若线程获取task超时||当前线程数大于maxPoolSize||workQueue为空? 线程数减1,返回null
            if ((wc > maxPoolSize || (timed && timedOut)) && (wc > 1 || taskQueue.isEmpty())) {
                wc = threadCount.get();
                if (threadCount.compareAndSet(wc, wc - 1)) return null;
                continue;
            }
            try {
                // 大于核心池数量限制的线程需要计时获取task
                Runnable r = timed ? taskQueue.poll(keepAliveTime, TimeUnit.NANOSECONDS) : taskQueue.take();
                //task获取并返回
                if (r != null) return r;
                //task获取超时
                timedOut = true;
            } catch (InterruptedException retry) {
                //线程中断
                timedOut = false;
            }
        }
    }

    /**
     * 拒绝任务
     *
     * @param command
     * @throws RuntimeException
     */
    private void reject(Runnable command) throws RuntimeException {
        System.out.println("拒绝执行任务:" + command.toString());
    }

    /**
     * 执行线程的封装类
     */
    private class Worker implements Runnable {
        Runnable firstTask;
        Thread thread;

        public Worker(Runnable firstTask) {
            this.firstTask = firstTask;
            this.thread = threadFactory.newThread(this);
        }

        @Override
        public void run() {
            runWorker(this);
        }
    }

    /**
     * 简单测试
     *
     * @param args
     */
    public static void main(String[] args) {
        ThreadPool threadPool = new ThreadPool(2, 5, 3000, TimeUnit.MILLISECONDS,
                new NamedThreadFactory(), new ArrayBlockingQueue<>(5));
        for (int i = 0; i < 20; i++) {
            threadPool.exec(new Runnable() {
                @Override
                public void run() {
                    for (int i = 0; i < 1; i++) {
                        System.out.println("Task: " + this.toString() + "is working!--------Executed by " + Thread.currentThread().getName());
                    }
                }
            });
        }
    }
}

NamedThreadFactory

/**
 * Thread factory to name the thread purposely
 *
 * @author jiangping
 * @version $Id: NamedThreadFactory.java, v 0.1 Sept 5, 2016 10:17:10 PM tao Exp $
 */
public class NamedThreadFactory implements ThreadFactory {

    private static final AtomicInteger poolNumber   = new AtomicInteger(1);
    private final AtomicInteger        threadNumber = new AtomicInteger(1);
    private final ThreadGroup          group;
    private final String               namePrefix;
    private final boolean              isDaemon;

    public NamedThreadFactory() {
        this("ThreadPool");
    }

    public NamedThreadFactory(String name) {
        this(name, false);
    }

    public NamedThreadFactory(String preffix, boolean daemon) {
        SecurityManager s = System.getSecurityManager();
        group = (s != null) ? s.getThreadGroup() : Thread.currentThread().getThreadGroup();
        namePrefix = preffix + "-" + poolNumber.getAndIncrement() + "-thread-";
        isDaemon = daemon;
    }

    /**
     * Create a thread.
     *
     * @see ThreadFactory#newThread(Runnable)
     */
    @Override
    public Thread newThread(Runnable r) {
        Thread t = new Thread(group, r, namePrefix + threadNumber.getAndIncrement(), 0);
        t.setDaemon(isDaemon);
        if (t.getPriority() != Thread.NORM_PRIORITY) {
            t.setPriority(Thread.NORM_PRIORITY);
        }
        return t;
    }

}

来自个人练习:My-Practice

GitHub:JiangJiangjungle

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值