手写一个 JAVA 线程池

  池化是我们在实际生产中经常用到的一种思想,通过一个 “池” 把资源统一的管理起来。可以达到对资源的合理管理、重复利用、减少资源创建/销毁的开销等目的。

  常见的比如常量池、连接池、线程池,今天我们手撸一个线程池。

  抛开语言特性,线程池无非是维护一堆线程阻塞等待任务的到来,并由主线程对任务线程的数量进行动态控制的组件。做到线程资源的复用及统一管理,同时避免大量的线程创建销毁的开销,并控制总的线程数量保证系统安全。

  功能要点

  1. 维护一个线程集合,执行外部提交的任务。

  2. 管理线程的集合,没有任务时维护一定数量的线程等待任务的到来。在运行过程中动态的根据任务的多少创建或销毁池中的线程,合理的分配资源。

  3. 提交成功的任务由线程池保证一定会被执行,同时对于提交失败的任务需要向调用方进行反馈,比如抛出异常由调用方进行处理。

  4. 对外提供安全的关闭方法,在保证任务队列中的任务都会被处理的前提下可以正确的回收池内所有的线程。

  实现思路要点

  1. 采用阻塞队列存储任务,在没有任务时任务线程阻塞在队列的 getTask 方法,并在 putTask 时唤醒等待在该队列的任务线程。

  2. 存储队列对外提供多种拒绝策略,由调用方在创建线程池时选择。调用方完全掌握任务被拒绝时的线程池的处理方式,从而做出正确的处理。

  3. 任务线程有自己的关闭标识,同时应正确的响应 interrupt 信号,以便在阻塞在任务队列上时可以被正确的关闭。

  4. 主线程轮询的检查任务数与活动线程数,动态的增加或减少活动线程,增加或减少线程数的算法会极大的影响整个线程池的性能,应根据任务的提交情况合理设计。

  5. 在使用同步锁的情况下,阻塞队列与主线程的同步锁应合理设计。因为同步锁为持有等待且不可抢占的,在我们不能完全掌握每个线程获取锁的顺序时,容易发生死锁的情况。

  本次实现的不足

  1. 锁的粒度不够细,为了避免死锁,主线程与阻塞队列共用一把同步锁,影响性能。

  2. 控制线程数的方式不够细致:任务数超过活动线程数的两倍时开启所有线程;无任务时才会缩减线程。

  3. 阻塞队列由自己实现,简单粗暴的使用 synchronized 控制互斥关系,效率不够高。

  4. 还可以在一些信号量上使用 CAS 或直接使用非阻塞队列提升性能。

  5. 没有定义对任务在执行过程中发生异常的处理。实际情况下应该记录没有执行成功的任务并记录下来,比如序列化任务对象记录到数据库,以便进行人工补偿。但因为本次实现任务对象直接使用 Runnable 的简单实现类,序列化没有意义,所以没有考虑这点。 

  测试结果

  先上测试代码:

public static void main(String[] args) {
        BasicThreadPool threadPool = new BasicThreadPool();
        for (int i = 0; i <= 100; i++) {
            final int num = i;
            threadPool.excute(
                    () -> {
                        System.out.println(Thread.currentThread().getName() + " : i am running to deal with task of " + num);
                        try {
                            Thread.sleep(1000);
                        } catch (InterruptedException e) {
                            return;
                        }
                    }
            );

        }
    }

  任务 sleep 一段时间是防止任务处理速度过快,无法测试动态增加或减少线程数的功能。测试结果:

  首先在提交任务时,我设置了初始线程数为 1 ,由于任务会 sleep 一段时间,造成了大量任务积压在了缓冲区,因此线程池马力全开,将线程数增加到了最大(add部分):

   积压任务处理完后,线程池缩减线程数(remove部分),只留下了核心线程,我设置的核心线程数为 3 ,为 0,1,2 号线程:

   然后我们改一下测试代码,增加关闭线程池的操作:

    public static void main(String[] args) {
        BasicThreadPool threadPool = new BasicThreadPool();
        for (int i = 0; i <= 100; i++) {
            final int num = i;
            threadPool.excute(
                    () -> {
                        System.out.println(Thread.currentThread().getName() + " : i am running to deal with task of " + num);
                        try {
                            Thread.sleep(1000);
                        } catch (InterruptedException e) {
                            return;
                        }
                    }
            );
            if (i == 90) {
                threadPool.shutdown();
            }
        }
    }

  在第 90 个任务提交时关闭线程池,之后的任务线程池拒绝接收,将抛出异常给调用者:

    线程池处于关闭状态时,依然会保证已经在队列中的任务会被执行完毕:

   任务队列中挤压的任务全部处理完毕后,会终止所有任务线程,包括阻塞状态的线程:

   剩下的全在代码里了。

   拒绝策略:

package theadPool;

/**
 * @Author Nxy
 * @Date 2020/3/14 14:26
 * @Description 拒绝策略
 */
public interface DenyPolicy {

    public void reject(Runnable task, ThreadPool pool);

    /**
     * @Author Nxy
     * @Date 2020/3/14 14:27
     * @Description 任务队列溢出异常
     */
    class OutOfRunnableQueueException extends RuntimeException {
        OutOfRunnableQueueException(String msg) {
            super(msg);
        }
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 14:30
     * @Description 直接丢弃任务
     */
    class DiscardDenyPolicy implements DenyPolicy {
        @Override
        public void reject(Runnable task, ThreadPool pool) {
            //do nothing
        }
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 14:42
     * @Description 抛出 任务队列溢出异常
     */
    class throwExceptionDenyPolicy implements DenyPolicy {
        @Override
        public void reject(Runnable task, ThreadPool pool) {
            throw new OutOfRunnableQueueException("task queue is full!");
        }
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 14:42
     * @Description 由提交线程直接执行任务
     */
    class RunnerDenyPolicy implements DenyPolicy {
        @Override
        public void reject(Runnable task, ThreadPool pool) {
            task.run();
        }
    }
}

  任务队列:

package theadPool;

/**
 * @Author Nxy
 * @Date 2020/3/14 14:23
 * @Description 任务队列
 */
public interface TaskQueue {

    //新任务追加到队列结尾
    void putTask(Runnable runnable);

    //获取任务,该方法是阻塞的,应当向上抛出 InterruptException 使调用方做出阻塞期间对 interrupt 信号的响应
    Runnable getTask() throws InterruptedException;

    //获取当前任务数
    int getSize();
}
package theadPool;

import java.util.LinkedList;

public class LinkedTaskQueue implements TaskQueue {
    //队列最大长度
    private final int maxSize;
    //任务达到最大数后的拒绝策略
    private final DenyPolicy denyPolicy;
    //任务队列,链表实现
    private final LinkedList<Runnable> queue = new LinkedList<Runnable>();

    private final ThreadPool threadPool;

    LinkedTaskQueue(ThreadPool threadPool, int maxSize, DenyPolicy denyPolicy) {
        this.maxSize = maxSize;
        this.denyPolicy = denyPolicy;
        this.threadPool = threadPool;
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 21:07
     * @Description 以下互斥区域使用 threadPool 对象锁,因为在 threadPool 中存在对这些方法的调用,
     * 在调用情况复杂,不好判断调用次序的情况下(并且存在持有等待、不可抢占的特性)用两把锁容易造成死锁
     */
    @Override
    public void putTask(Runnable runnable) {
        synchronized (threadPool) {
            //超出最大任务数或者线程池已关闭,采取拒绝策略
            if (getSize() >= maxSize || threadPool.isShutDown()) {
                denyPolicy.reject(runnable, threadPool);
                return;
            }
            //任务追加到任务队列结尾
            queue.addLast(runnable);
            //唤醒等待在任务队列的工作线程
            threadPool.notify();
        }
    }

    @Override
    public Runnable getTask() throws InterruptedException {
        Runnable returnRunnable;
        synchronized (threadPool) {
            if (queue.isEmpty()) {
                System.out.println(Thread.currentThread().getName() + " 等待在缓冲区");
                threadPool.wait();
            }
            //先进先出队列
            returnRunnable = queue.removeFirst();
        }
        return returnRunnable;
    }

    @Override
    public int getSize() {
        synchronized (queue) {
            return queue.size();
        }
    }
}

  任务线程:

package theadPool;

/**
 * @Author Nxy
 * @Date 2020/3/14 14:47
 * @Description 线程池中的线程
 */
public class PoolTask extends Thread implements Runnable {
    //任务队列
    private final TaskQueue queue;
    //当前线程运行标志位
    private volatile boolean isRunning = true;

    //传入队列
    public PoolTask(TaskQueue queue) {
        this.queue = queue;
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 14:52
     * @Description 除通过 isRunning 标志位可关闭该线程外,interrupt 信号也可关闭该线程
     */
    @Override
    public void run() {
        while (isRunning && !(this.isInterrupted())) {
            try {
                //获取任务并执行
                Runnable runnable = queue.getTask();
                runnable.run();
            } catch (InterruptedException e0) {
                System.out.println(Thread.currentThread().getName() + " 接收到 interrupt 信号,终止执行");
                return;
            } catch (Exception e1) {
//                e1.printStackTrace();
                return;
            }
        }
        System.out.println(Thread.currentThread().getName() + " 终止执行");
    }

    //安全的关闭该线程
    public void shutdown() {
        this.isRunning = false;
        this.interrupt();
    }
}

  主类:

package theadPool;

import java.util.LinkedList;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * @Author Nxy
 * @Date 2020/3/14 16:55
 * @Description 线程池
 */
public class BasicThreadPool extends Thread implements ThreadPool {
    //初始线程数
    private final int initSize;
    //最大线程数
    private final int maxSize;
    //最大任务数
    private final int maxTaskSzie;
    //核心线程数
    private final int coreSize;
    //当前活跃线程数
    private int activeCount;
    //任务队列
    private final TaskQueue taskQueue;
    //关闭控制位
    private volatile boolean isShutDown = false;
    //主线程扫描线程池状态时间间隔
    private final long keepAliveTime;
    //睡眠工具类
    private final TimeUnit timeUtil;

    //线程队列
    private final LinkedList<TaskThread> taskPool = new LinkedList<TaskThread>();

    //任务拒绝策略
    private final DenyPolicy denyPolicy;

    //默认参数
    private static final int defaultInitSize = 1;
    private static final int defaultMaxSize = 8;
    private static final int defaultCoreSize = 3;
    private static final DenyPolicy defaultDenyPolicy = new DenyPolicy.throwExceptionDenyPolicy();
    private static final int defaultMaxTaskSize = 1000;
    private static final TimeUnit defaultTimeUtil = TimeUnit.MILLISECONDS;
    private static final int defaultKeepAlive = 1;

    //默认参数构造
    public BasicThreadPool() {
        this(defaultInitSize, defaultMaxSize, defaultCoreSize, defaultKeepAlive, defaultDenyPolicy, defaultMaxTaskSize, defaultTimeUtil);
    }

    //自定义拒绝策略构造
    public BasicThreadPool(DenyPolicy denyPolicy) {
        this(defaultInitSize, defaultMaxSize, defaultCoreSize, defaultKeepAlive, denyPolicy, defaultMaxTaskSize, defaultTimeUtil);
    }

    //自定义参数构造
    public BasicThreadPool(int initSize, int maxSize, int coreSize, long keepAliveTime, DenyPolicy denyPolicy, int maxTaskSzie, TimeUnit timeUtil) {
        this.initSize = initSize;
        this.maxSize = maxSize;
        this.coreSize = coreSize;
        this.denyPolicy = denyPolicy;
        this.keepAliveTime = keepAliveTime;
        this.maxTaskSzie = maxTaskSzie;
        this.timeUtil = timeUtil;
        taskQueue = new LinkedTaskQueue(this, maxTaskSzie, denyPolicy);
        this.init();
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 16:03
     * @Description 静态线程工厂
     */
    static class ThreadFactory {
        private static final AtomicInteger groupCounter = new AtomicInteger(1);
        private static final AtomicInteger counter = new AtomicInteger(0);
        private static final ThreadGroup group = new ThreadGroup("BasicThreadPool group : " + groupCounter.getAndIncrement());

        public static Thread createThread(Runnable runnable) {
            return new Thread(group, runnable, "BasicThreadPool group " + groupCounter.get() + " : " + counter.getAndIncrement());
        }
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 15:48
     * @Description 线程池内线程,poolTask 与 thread 的结合,poolTask 携带run,thread 携带池中参数
     */
    class TaskThread {
        final PoolTask poolTask;
        final Thread thread;

        TaskThread(PoolTask poolTask, Thread thread) {
            this.thread = thread;
            this.poolTask = poolTask;
        }

        public void shutdown() {
            poolTask.shutdown();
            /**
             *   @Author Nxy
             *   @Date 2020/3/14 22:04
             *   @Description 很重要,仅仅给 poolTask 发送interrupt,所在线程并不会收到信号
             *   线程与 Thread 或 runnable 对象是绑定的,但我们用 runnable 提交任务到 Thread 时,线程绑定的
             *   是 Thread 对象,因此要通过 Thread 对象发送 interrupt 信号
             */
            thread.interrupt();
        }
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 16:17
     * @Description 新增一个活动线程
     */
    private void addThread() {
        PoolTask poolTask = new PoolTask(taskQueue);
        Thread thread = BasicThreadPool.ThreadFactory.createThread(poolTask);
        TaskThread taskThread = new TaskThread(poolTask, thread);
        synchronized (this) {
            activeCount++;
            taskPool.addLast(taskThread);
        }
        thread.start();
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 16:35
     * @Description 关闭一个活动线程
     */
    private void removeThread() {
        synchronized (this) {
            TaskThread taskThread = taskPool.removeLast();
            taskThread.shutdown();
            activeCount--;
        }
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 16:22
     * @Description 不断检查线程数,动态调整活动线程数量
     */
    @Override
    public void run() {
        while (!this.isShutDown && !Thread.currentThread().isInterrupted()) {
            try {
                timeUtil.sleep(10);
            } catch (InterruptedException e) {
                //休眠时间响应 interrupt 信号
                this.isShutDown = true;
                break;
            }
            //有任务且线程数小于核心线程数
            synchronized (this) {
                if (this.isShutDown) {
                    //DCL 检查
                    break;
                }
                //任务数超过当前线程数的两倍,线程池马力全开
                if (taskQueue.getSize() >= activeCount * 2 && activeCount < maxSize) {
                    int beforeActive = activeCount;
                    for (int i = activeCount; i < maxSize; i++) {
                        addThread();
                    }
                    System.out.println("add : actice->" + beforeActive + " ---> " + activeCount);
                    continue;
                }
                //任务数不为0且活动线程数小于核心线程数,新增线程数过核心数线
                if (taskQueue.getSize() > 0 && activeCount < coreSize) {
                    for (int i = activeCount; i < coreSize; i++) {
                        addThread();
                    }
                    System.out.println("add : actice->" + activeCount);
                    continue;
                }
                //无任务且线程数大于核心线程数,关闭线程,仅留下核心线程数的线程
                if (taskQueue.getSize() == 0 && activeCount > coreSize) {
                    for (int i = coreSize; i < activeCount; i++) {
                        removeThread();
                        System.out.println("remove : actice->" + activeCount);
                    }
                }
            }
        }
        System.out.println("******线程池主线程关闭******");
    }

    private void init() {
        this.start();
        for (int i = 0; i < initSize; i++) {
            addThread();
        }
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 16:53
     * @Description 向线程池中提交新任务
     */
    @Override
    public void excute(Runnable runnable) {
        if (this.isShutDown || Thread.currentThread().isInterrupted()) {
            throw new RuntimeException("threadPool is closed!");
        }
        taskQueue.putTask(runnable);
    }

    /**
     * @Author Nxy
     * @Date 2020/3/14 16:53
     * @Description 关闭线程池
     */
    @Override
    public void shutdown() {
        synchronized (this) {
            this.isShutDown = true;
        }
        System.out.println("正在等待处理任务队列剩余任务");
        while (taskQueue.getSize() != 0) {
            //等待任务队列中的任务全部被处理完
        }
        System.out.println("开始关闭线程池中线程");
        taskPool.forEach(threadTask -> {
                    threadTask.shutdown();
                }
        );
        this.interrupt();
        System.out.println("******线程池中所有任务线程已关闭******");
    }

    @Override
    public int getInitSize() {
        if (this.isShutDown || Thread.currentThread().isInterrupted()) {
            throw new RuntimeException("threadPool is closed!");
        }
        return this.initSize;
    }

    @Override
    public int getMaxSize() {
        if (this.isShutDown || Thread.currentThread().isInterrupted()) {
            throw new RuntimeException("threadPool is closed!");
        }
        return this.maxSize;
    }

    @Override
    public int getQueSize() {
        if (this.isShutDown || Thread.currentThread().isInterrupted()) {
            throw new RuntimeException("threadPool is closed!");
        }
        return this.taskQueue.getSize();
    }

    @Override
    public int getActiveCount() {
        if (this.isShutDown || Thread.currentThread().isInterrupted()) {
            throw new RuntimeException("threadPool is closed!");
        }
        synchronized (this) {
            return this.activeCount;
        }
    }

    @Override
    public boolean isShutDown() {
        if (this.isShutDown || Thread.currentThread().isInterrupted()) {
            throw new RuntimeException("threadPool is closed!");
        }
        synchronized (this) {
            return this.isShutDown;
        }
    }
}

 

posted @ 2020-03-14 23:33  牛有肉  阅读(3655)  评论(0编辑  收藏  举报