浅析CountDownLatch源码
需要了解 AQS 知识。
CountDownLatch 能够等待一个或一组线程,直到其他线程执行完成(计数器减为 0)时,才继续执行。
其实调用线程的 join() 方法能够实现等待线程完成后再继续执行的场景。
不过 CountDownLatch 更为灵活:https://blog.csdn.net/zhutulang/article/details/48504487
CountDownLatch 实现的原理大致如下:
创建时传入计数器初始值,子任务完成时,AQS 中的 state 属性可以表示等待完成的任务数量,没完成一项计数 -1,计数器为 0 时,唤醒调用线程。
构造方法
CountDownLatch 的构造方法必须传入一个整形作为计数器的初始值,该数值用于初始化 Sync。
private final Sync sync;
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
属性
Sync
CountDownLatch 内部方法实现都调用了 Sync,可见 Sync 为该类核心。
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
// Sync 继承自 AQS,从名称可以看出初始化值传入 AQS 的 state 属性
setState(count);
}
int getCount() {
return getState();
}
// 使用了 AQS 的共享模式
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// 通过自选操作实现自减 1
for (;;) {
// 获取更新的 state 值
int c = getState();
// 若无需释放锁(=0)
if (c == 0)
return false;
// 若释放锁则递减
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
方法
// 获取还需要等待的任务数量
public long getCount() {...}
public String toString() {...}
public void await()
调用 await() 方法后,调用线程会被阻塞,直到出现下面情况之一:
- 所有任务线程调用 countDown 方法,即计数器为 0
- 其他线程调用当前线程的 interrupt() 方法进行中断,此时会抛出异常
// 等待任务完成,计数器为 0 时返回
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
// 设置了超时时间
public boolean await(long timeout, TimeUnit unit) { ... }
acquireSharedInterruptibly 是 AQS 中定义的方法
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 线程被中断则抛出异常
if (Thread.interrupted())
throw new InterruptedException();
// 至少尝试一次 tryAcquireShared
// 成功:返回
// 失败:线程进入等待队列
if (tryAcquireShared(arg) < 0) // (1)
// 进入等待队列
doAcquireSharedInterruptibly(arg); // (2)
}
(1)CountDownLatch 中 sync 定义的方法,判断 state 是否为 0
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
doAcquireSharedInterruptibly 方法创建了共享模式的 AQS 节点进入等待队列进行排队。
CountDownLatch 设置 state 后未置为 0,调用 await 的线程都会进行等待。
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
// 创建共享模式节点,并加入等待队列
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
// 自旋尝试获得锁
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
// 当前线程是否需要挂起
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
public void countDown()
// 递减计数器,计数器等于 0,则释放所有等待的线程
public void countDown() { sync.releaseShared(1); }
AQS 中的 releaseShared 实现
public final boolean releaseShared(int arg) {
// 会执行 sync 的 tryReleaseShared 方法 -1 ,然后进行共享锁的释放操作
if (tryReleaseShared(arg)) { // (1)
doReleaseShared(); // (2)
return true;
}
return false;
}
(1)CountDownLatch 中 sync 定义的方法,判断 state 是否为 0
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false; // state 计数为 0 时,释放失败
int nextc = c-1;
if (compareAndSetState(c, nextc)) // 将 state 值使用 CAS 置为 state-1
return nextc == 0;
}
}