Java高并发之CountDownLatch源码分析

概述

CountDownLatch 允许一个或多个线程等待直到在其他线程中执行的一组操作完成的同步辅助。简单来说,就是 CountDownLatch 内部维护了一个计数器,每个线程完成自己的操作之后都会将计数器减一,然后会在计数器的值变为 0 之前一直阻塞,直到计数器的值变为 0.

简单使用

这个例子主要演示了,如何利用 CountDownLatch 去协调多个线程同时开始运行。这个时候的 CountDownLatch 中的计数器的现实含义是等待创建的线程个数,每个线程在开始任务之前都会调用 await() 方法阻塞,直到所有线程都创建好,每当一个线程创建好后,都会提交调用 countDown() 方法将计数器的值减一 (代表待创建的线程数减一)。

public static void main(String[] args) {
    Test countDownLatchTest=new Test();
    countDownLatchTest.runThread();
}
//计数器为10,代表有10个线程等待创建
CountDownLatch countDownLatch=new CountDownLatch(10);

/**
 * 创建一个线程
 * @return
 */
private Thread createThread(int i){
    Thread thread=new Thread(new Runnable() {
        @Override
        public void run() {
            try {
                //在此等待,直到计数器变为0
                countDownLatch.await();
                System.out.println("thread"+Thread.currentThread().getName()+"准备完毕"+System.currentTimeMillis());
            }catch (InterruptedException e){
                e.printStackTrace();
            }

        }
    });
    thread.setName("thread-"+i);
    return  thread;
}

public void runThread(){
    ExecutorService executorService= Executors.newFixedThreadPool(10);

    try {
        for(int i=0;i<10;i++){
            Thread.sleep(100);
            executorService.submit(createThread(i));
            //一个线程创建好了,待创建的线程数减一
            countDownLatch.countDown();
        }
    }catch (InterruptedException e){
        e.printStackTrace();
    }

}

下面我们就以这个例子,来解释源码:

源码分析

继承体系

从锁的分类上来讲,CountDownLatch 其实是一个” 共享锁 “。还有一个需要注意的是 CountDownLath 是响应中断的,如果线程在对锁进行操作的期间发生了中断,会直接抛出 InterruptedException。

源码分析

计数器的本质是什么?

刚才我们也提到了,CountDownLatch 中一个非常重要的东西就是计数器。那么我们首先需要分析的就是源码中哪个部分充当了计数器的角色。
我们通过构造方法来查看:
我们的代码CountDownLatch countDownLatch=new CountDownLatch(10);背后实际上是调用了下面这个方法:

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

而这个 Sync 的实例化又做了什么工作呢?

Sync(int count) {
    setState(count); //就是修改了AQS中的state值
}

现在已经解决了我们的第一个问题,实际上 AQS 中的 state 充当了计数器。

await 方法

  1. await 方法实际上是调用了 sync 的一个方法
public void await() throws InterruptedException {
     sync.acquireSharedInterruptibly(1);
}
  1. sync 的void acquireSharedInterruptibly(int arg)的实现如下
public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
        //如果线程中断了,则抛异常。
        //证明了之前所说的CountDownLatch是会响应中断的
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
}
  1. 如果没有中断,就会调用tryAcquireShared(arg)
    它的实现非常的简单,如果 state 为 0,就返回 1,否则返回 - 1
protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
}
  1. 如果 state 不为 0,就会返回 - 1,if 条件成立,就会调用doAcquireSharedInterruptibly(arg)
    这个方法的实现,稍微复杂一点,但这个方法也不陌生了,它的功能就是把该线程加入等待队列中并阻塞,但是在入队之后,不一定会立即 park 阻塞,它会判断自己是否是第二个节点,如果是就会再次尝试获取。
private void doAcquireSharedInterruptibly(int arg)
        throws InterruptedException {
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                final Node p = node.predecessor(); //获取当前节点的前驱节点
                if (p == head) {//前一个节点是头节点
                    int r = tryAcquireShared(arg); //去看一看state是否为0,步骤3分析过
                    if (r >= 0) {
                    //如果state目前为0,就出队
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    //进入阻塞队列阻塞,如果发生中断,则抛异常
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
}

CountDownLatch 的 await 方法比其它几个锁的实现简单得多。不过需要注意的一点就是 CountDownLatch 是会响应中断的,这一点在源码中也有多处体现。

countDown 方法

  1. countDown 方法实际上是调用 sync 中的一个方法
public void countDown() {
        sync.releaseShared(1);
}
  1. boolean releaseShared(int arg)的具体实现如下:
public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
}
  1. tryReleaseShared(arg)方法的具体实现如下:
protected boolean tryReleaseShared(int releases) {
               // Decrement count; signal when transition to zero
               for (;;) {//自旋
                   int c = getState();
                   if (c == 0)//计数器已经都是0了,当然会释放失败咯
                       return false;
                   int nextc = c-1;//释放后,计数器减一
                   if (compareAndSetState(c, nextc))//CAS修改计数器
                       return nextc == 0;
               }
}

这个方法就是去尝试直接修改 state 的值。如果 state 的修改成功,且修改后的 state 值为 0,就会返回 true。就会执行doReleaseShared();方法。

  1. doReleaseShared();的实现如下,它的作用就是 state 为 0 的时候,去唤醒等待队列中的线程。
private void doReleaseShared() {
        /*
         * Ensure that a release propagates, even if there are other
         * in-progress acquires/releases.  This proceeds in the usual
         * way of trying to unparkSuccessor of head if it needs
         * signal. But if it does not, status is set to PROPAGATE to
         * ensure that upon release, propagation continues.
         * Additionally, we must loop in case a new node is added
         * while we are doing this. Also, unlike other uses of
         * unparkSuccessor, we need to know if CAS to reset status
         * fails, if so rechecking.
         */
        for (;;) { //自旋
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h);
                }
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
}

现在基本源码已经分析完毕了,只要理解了 AQS 和 CountDownLatch 的计数器到底是什么,就能够很好的理解 CountDownLatch 的原理了。

关注微信公众号:【入门小站】,解锁更多知识点

posted @ 2021-01-30 22:51  入门小站  阅读(103)  评论(0编辑  收藏  举报