Java并发编程之CountDownLatch
简介
在日常的开发中,可能会遇到这样的场景:开启多个子线程执行一些耗时任务,然后在主线程汇总,在子线程执行的过程中,主线程保持阻塞状态直到子线程完成任务。
使用CountDownLatch类或者Thread类的join()方法都能实现这一点,下面通过例子来介绍这两种实现方法。
CountDownLatch的使用
一个小例子,等待所有玩家准备就绪,然后游戏才开始。
使用join方法实现:
public class Demo {
public static void main(String[] args) throws InterruptedException {
Runnable runnable = () -> {
System.out.println(Thread.currentThread().getName() + ":准备就绪");
};
Thread thread1 = new Thread(runnable, "一号玩家");
Thread thread2 = new Thread(runnable, "二号玩家");
Thread thread3 = new Thread(runnable, "三号玩家");
Thread thread4 = new Thread(runnable, "四号玩家");
Thread thread5 = new Thread(runnable, "五号玩家");
thread1.start();
thread2.start();
thread3.start();
thread4.start();
thread5.start();
//主线程等待子线程执行完成再执行
thread1.join();
thread2.join();
thread3.join();
thread4.join();
thread5.join();
System.out.println("---游戏开始---");
}
}
/*
* 输出结果:
* 二号玩家:准备就绪
* 五号玩家:准备就绪
* 四号玩家:准备就绪
* 三号玩家:准备就绪
* 一号玩家:准备就绪
* ---游戏开始---
*/
使用CountDownLatch实现:
public class Demo {
public static void main(String[] args) throws InterruptedException {
//创建计数器初始值为5的CountDownLatch
CountDownLatch countDownLatch = new CountDownLatch(5);
Runnable runnable = () -> {
try{
System.out.println(Thread.currentThread().getName() + ":准备就绪");
}catch (Exception ex){
ex.printStackTrace();
}finally {
//计数器值减一
countDownLatch.countDown();
}
};
Thread thread1 = new Thread(runnable, "一号玩家");
Thread thread2 = new Thread(runnable, "二号玩家");
Thread thread3 = new Thread(runnable, "三号玩家");
Thread thread4 = new Thread(runnable, "四号玩家");
Thread thread5 = new Thread(runnable, "五号玩家");
thread1.start();
thread2.start();
thread3.start();
thread4.start();
thread5.start();
//等待计数器值为0
countDownLatch.await();
System.out.println("---游戏开始---");
}
}
/*
* 输出结果:
* 四号玩家:准备就绪
* 五号玩家:准备就绪
* 一号玩家:准备就绪
* 三号玩家:准备就绪
* 二号玩家:准备就绪
* ---游戏开始---
*/
CountDownLatch内部包含一个计数器,计数器的初始值为CountDownLatch构造函数传入的int类型的参数,countDown方法会递减计数器值,await方法会阻塞当前线程直到计数器值为0。
两种方式的区别:
当调用子线程的join方法时,会阻塞当前线程直到子线程结束。而CountDownLatch相对比较灵活,无需等到子线程结束,只要计数器值为0,await方法就会返回。
CountDownLatch源码
CountDownLatch源码:
public class CountDownLatch {
/**
* CountDownLatch的同步控制,使用AQS的状态值作为计数器值。
*/
private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;
Sync(int count) {
setState(count);
}
int getCount() {
return getState();
}
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
int c = getState();
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}
private final Sync sync;
/**
* 构造函数,初始化计数器
*/
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
/**
* 阻塞当前线程直到计数器值为0
*/
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
/**
* 阻塞当前线程直到计数器值为0或者超时
*/
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
/**
* 递减计数器值,当计数器值为0时,释放所有等待的线程。
*/
public void countDown() {
sync.releaseShared(1);
}
/**
* 返回当前计数器值
*/
public long getCount() {
return sync.getCount();
}
public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}
}
通过源码可以看出,CountDownLatch内部是使用AQS实现的,它使用AQS的状态变量state作为计数器值,静态内部类Sync继承了AQS并实现了tryAcquireShared和tryReleaseShared方法。
接下来重点看下await()和countDown()的源码:
await()方法内部调用的是AQS的acquireSharedInterruptibly方法,会将当前线程放入AQS队列等待,直到计数值为0。
public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
//判断当前线程是否被中断,如果线程被中断则抛出异常
if (Thread.interrupted())
throw new InterruptedException();
//判断计数器值是否为0,为0则直接返回,否则进AQS队列进行等待。
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
//CountDownLatch中Sync的tryAcquireShared方法实现,直接判断计数器值是否为0。
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
countDown()方法内部调用的是AQS的releaseShared方法,每次调用都会递减计数值,直到计数值为0则调用AQS释放资源的方法。
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
//释放资源
doReleaseShared();
return true;
}
return false;
}
//CountDownLatch中Sync的tryReleaseShared方法实现
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
//计数值为0直接返回
if (c == 0)
return false;
//设置递减后的计数值
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}