CountDownLatch源码剖析
CountDownLatch
门闩,他可以让多个线程 都阻塞在⼀个地⽅,直到 所有线程任务都执⾏完成。
测试案例:
先让子线程执行完了,再让主线程执行
public class CountDownLatchDemo {
public static void main(String[] args) {
CountDownLatchDemo demo = new CountDownLatchDemo();
demo.test();
}
private void test() {
// 定义计数器值为10的CountDownLatch对象
CountDownLatch countDownLatch = new CountDownLatch(10);
// 创建10个子线程执行任务,最少需要10个子线程
for (int i = 0; i < 10; i++) {
new Thread(() -> this.task(countDownLatch)).start();
}
// 在此等待所有任务完成Ω
try {
countDownLatch.await();
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println("Thread-main end...");
}
private void task(CountDownLatch countDownLatch) {
SleepUtils.sleep(new Random().nextInt(3));
// 任务执行完后将计数器减1
System.out.println(Thread.currentThread().getName() + " end...");
countDownLatch.countDown();
}
}
运行结果:
Thread-3 end...
Thread-1 end...
Thread-6 end...
Thread-2 end...
Thread-0 end...
Thread-4 end...
Thread-5 end...
Thread-8 end...
Thread-7 end...
Thread-9 end...
Thread-main end...
CountDownLatch常用方法源码剖析
public class CountDownLatch {
// 可以看到底层使用到了AQS
private static final class Sync extends AbstractQueuedSynchronizer {
Sync(int count) {
setState(count);
}
// ......
}
// 就是上面这个
private final Sync sync;
// count一般传入要锁住的线程的数量
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count); // 底层操作的是state变量,见上
}
//
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// 小于0表示,还有线程没有执行完任务了
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
// 向队尾添加一个SHARED类型的节点
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {// 因为aqs的head是不存东西的,这里是主线程的Node节点,一定是位true的
int r = tryAcquireShared(arg);// 判断state的值
if (r >= 0) {// 表示所有的子线程都执行完了
setHeadAndPropagate(node, r);// 释放头结点
p.next = null; // help GC
failed = false;
return;
}
}
// 检查Node中的线程是否需要被挂起,如果返回true则说明需要挂起,然后执行后续挂起方法 parkAndCheckInterrupt,否则重新自旋。
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
int ws = pred.waitStatus;// waitStatus默认是0
if (ws == Node.SIGNAL) // Node.SIGNAL = -1,不走
/*
* This node has already set status asking a release
* to signal it, so it can safely park.
*/
return true;
if (ws > 0) { //不走
/*
* Predecessor was cancelled. Skip over predecessors and
* indicate retry.
*/
do {
node.prev = pred = pred.prev;
} while (pred.waitStatus > 0);
pred.next = node;
} else {// 进入到这里
/*
* waitStatus must be 0 or PROPAGATE. Indicate that we
* need a signal, but don't park yet. Caller will need to
* retry to make sure it cannot acquire before parking.
*/
compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
}
return false;
}
// countDown
public void countDown() {
sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
// 总结:就是通过自旋的方式等待所有任务执行完了。
for (;;) {
int c = getState();
// 如果state==0表示已经没有线程再执行任务了
if (c == 0)
return false;
int nextc = c-1;
// 否则,通过cas的方式缉拿state的值-1,说白了就是释放一个执行完任务的线程
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}