CountDownLatch源码解析(基于JDK8)
1 介绍
CountDownLatch是一种AQS共享锁,可以看之前的介绍AQS(二)共享锁(基于JDK 8)
CountDownLatch 定义了一个计数器,和一个阻塞队列, 当计数器的值递减为0之前,阻塞队列里面的线程处于挂起状态,当计数器递减到0时会唤醒阻塞队列所有线程,这里的计数器是一个标志,可以表示一个任务一个线程,也可以表示一个倒计时器,CountDownLatch可以解决那些一个或者多个线程在执行之前必须依赖于某些必要的前提业务先执行的场景。
CountDownLatch 是一个共享锁,其中的 Sync 继承了 AQS,实现了共享的两个方法 tryAcquireShared 和 tryReleaseShared。
这里面计数器记录的值就是 state,初始化后只能不断减少,直到减为0,则唤醒堵塞队列的所有线程。 tryReleaseShared 会减少,而 tryAcquireShared 不会增加。
public class CountDownLatch {
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
// 初始化
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
// 重写父类方法,用于获取资源
// 可以看到,这里没有执行获取资源的操作
// 也就是在初始化后,资源只能释放减少,不能获取
// 返回 1 表示 acquire 成功,-1 表示失败。
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
// 重写父类方法,用于释放资源
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
// 如果已经是0,直接返回false
// 释放失败
if (c == 0)
return false;
int nextc = c-1;
// 其他情况,CAS 修改后,返回修改后的值是否为0
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
}
2 CountDownLatch 的方法
下面是 CountDownLatch 的方法,不过是简单调用了 Sync 自己的和 AQS 的方法。
这里的方法称为 await,对比在 AQS 第三篇文章讲解的 Condition 的 await,可以猜出这几个 await 应该也会抛出中断异常。
当然,这里没有 signal。
// 初始化方法,要求 count>=0
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public void countDown() {
sync.releaseShared(1);
}
public long getCount() {
return sync.getCount();
}
下面是用到的 AQS 的方法,前面在讲共享锁的时候讲过 acquireShared 和 releaseShared,对于这里的两种另外形式 acquire ,只分析其中的区别。
首先是 acquireSharedInterruptibly,对于中断抛出异常
// AQS 的方法
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 开始就检查是否中断
if (Thread.interrupted())
throw new InterruptedException();
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
// 对于该类的 Sync 来说,只会返回 1 和 -1
// 如果成功,说明此时state为0
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
// 这里直接抛出异常
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
// Sync 的方法
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
tryAcquireSharedNanos 对于中断会抛出异常,且超时会失败。
// AQS的方法
public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
// 开始就检查是否中断
if (Thread.interrupted())
throw new InterruptedException();
return tryAcquireShared(arg) >= 0 ||
doAcquireSharedNanos(arg, nanosTimeout);
}
private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (nanosTimeout <= 0L)
return false;
final long deadline = System.nanoTime() + nanosTimeout;
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
// 对于该类的 Sync 来说,只会返回 1 和 -1
// 如果成功,说明此时state为0
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return true;
}
}
// 计算剩余时间
nanosTimeout = deadline - System.nanoTime();
// 超时返回失败
if (nanosTimeout <= 0L)
return false;
// 应该堵塞返回 true,且剩余时间足够长,超过1000,
// 这时可以堵塞 nanosTimeout 这么长时间
if (shouldParkAfterFailedAcquire(p, node) &&
nanosTimeout > spinForTimeoutThreshold)
LockSupport.parkNanos(this, nanosTimeout);
if (Thread.interrupted())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
static final long spinForTimeoutThreshold = 1000L;
// Sync 的方法
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
最后是AQS 的 releaseShared,尝试释放,成功后唤醒后一个节点,并返回 true;失败返回 false。
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
3 总结及使用
在前面可以看到,tryAcquireShared 只有在state为0才会返回 1,表示成功,且不会增加 state;tryReleaseShared 每次减少 state,并且成功的条件是 c 为1,CAS 修改后变成0。
下面举一个例子。
1 首先初始化 CountDownLatch,假设 state 为 3;
2 放入大量线程执行 await
,这些方法执行 tryAcquireShared
都是 -1,最终被堵塞在 parkAndCheckInterrupt
;
3 然后某个或某些线程执行了 countDown
3次,最后返回 true,会唤醒队列中后一个节点;
4 节点醒来后,执行 tryAcquireShared
返回1,成功,又会在 setHeadAndPropagate
设置头部并唤醒后一个节点,直到所有节点内的线程都被唤醒。
import java.util.concurrent.CountDownLatch;
public class B {
public static CountDownLatch countDownLatch = null;
public static void main(String[] args) throws Exception{
countDownLatch = new CountDownLatch(3);
Thread [] threads = new Thread[10];
for(int i = 0;i<threads.length;i++){
threads[i] = new MyThread(countDownLatch,"线程"+i);
threads[i].start();
}
// 在三次 countDown 后,所有线程都唤醒,
// 根据 debug 可知,队列中的线程不一定是 0-9 这个顺序,
// 这是因为从 await 到堵塞的位置,有很多指令要执行
// 随时会被其他线程打断
System.out.println("倒计时3");
countDownLatch.countDown();
System.out.println("倒计时2");
countDownLatch.countDown();
System.out.println("倒计时1");
countDownLatch.countDown();
// 多执行几次没有任何影响
//countDownLatch.countDown();
//countDownLatch.countDown();
}
}
class MyThread extends Thread{
private CountDownLatch countDownLatch = null;
public MyThread(CountDownLatch countDownLatch,String name){
super(name);
this.countDownLatch = countDownLatch;
}
@Override
public void run() {
try {
// 先进入队列等待
countDownLatch.await();
} catch(Exception e){
e.printStackTrace();
}
System.out.println(this.getName()+"开始运行");
}
}