【Java并发】 - CountDownLatch使用以及原理
概述
CountDownLatch是一个用来控制并发的很常见的工具,它允许一个或者多个线程等待其他的线程执行到某一操作,比如说需要去解析一个excel的数据,为了更快的解析则每个sheet都使用一个线程去进行解析,但是最后的汇总数据的工作则需要等待每个sheet的解析工作完成之后才能进行,这就可以使用CountDownLatch。
使用
例子:
这里有三个线程(main,thread1,thread2),其中main线程将调用countDownLatch的await方法去等待另外两个线程的某个操作的结束(调用countDownLatch的countDown方法)。
package 多线程并发; import java.util.concurrent.CountDownLatch; /** * Created by z84102272 on 2018/7/18. * CountDownLatch是一个用来控制并发的很常见的工具,它允许一个或者多个线程等待其他的线程执行到某一操作, * 比如说需要去解析一个excel的数据,为了更快的解析则每个sheet都使用一个线程去进行解析, * 但是最后的汇总数据的工作则需要等待每个sheet的解析工作完成之后才能进行,这就可以使用CountDownLatch。 */ public class CountDownLatchTest { public static void main(String[] args) throws InterruptedException { CountDownLatch countDownLatch = new CountDownLatch(2){ @Override public void await() throws InterruptedException { super.await(); System.out.println(Thread.currentThread().getName() + "count down is ok"); } }; Thread thread1 = new Thread(new Runnable() { @Override public void run() { try { Thread.sleep(1000); }catch (InterruptedException e){ e.printStackTrace(); } System.out.println(Thread.currentThread().getName() + "is done"); countDownLatch.countDown(); } },"thread1"); Thread thread2 = new Thread(new Runnable() { @Override public void run() { try{ Thread.sleep(2000); }catch (InterruptedException e){ e.printStackTrace(); } System.out.println(Thread.currentThread().getName() + "is done"); countDownLatch.countDown(); } },"thread2"); thread1.start(); thread2.start(); countDownLatch.await(); } }
这里的CountDownLatch的构造函数中使用的int型变量的意思是需要等待多少个操作 的完成。这里是2所以需要等到调用了两次countDown()方法之后主线程的await()方法才会返回。这意味着如果我们错误的估计了需要等待的操作的个数或者在某个应该调用countDown()方法的地方忘记了调用那么将意味着await()方法将永远的阻塞下去。
实现原理
CountDownLatch类实际上是使用计数器的方式去控制的,不难想象当我们初始化CountDownLatch的时候传入了一个int变量这个时候在类的内部初始化一个int的变量,每当我们调用countDownt()方法的时候就使得这个变量的值减1,而对于await()方法则去判断这个int的变量的值是否为0,是则表示所有的操作都已经完成,否则继续等待。
实际上如果了解AQS的话应该很容易想到可以使用AQS的共享式获取同步状态的方式来完成这个功能。而CountDownLatch实际上也就是这么做的。
从结构上来看CountDownLatch的实现还是很简单的,通过很常见的继承AQS的方式来完成自己的同步器。
CountDownLatch的同步器实现:
private static final class Sync extends AbstractQueuedSynchronizer { private static final long serialVersionUID = 4982264981922014374L; //初始化state Sync(int count) { setState(count); } int getCount() { return getState(); } //尝试获取同步状态 //只有当同步状态为0的时候返回大于0的数1 //同步状态不为0则返回-1 protected int tryAcquireShared(int acquires) { return (getState() == 0) ? 1 : -1; } //自旋+CAS的方式释放同步状态 protected boolean tryReleaseShared(int releases) { // Decrement count; signal when transition to zero for (;;) { int c = getState(); if (c == 0) return false; int nextc = c-1; if (compareAndSetState(c, nextc)) return nextc == 0; } } }
比较关键的地方是tryAquireShared()方法的实现,因为在父类的AQS中aquireShared()方法在调用tryAquireShared()方法的时候的判断依据是返回值是否大于零。
public final void acquireShared(int arg) { if (tryAcquireShared(arg) < 0) //失败则进入等待队列 doAcquireShared(arg); }
最后:由于CountDownLatch需要开发人员很明确需要等待的条件,否则很容易造成await()方法一直阻塞的情况。