1. 前言
在日常开发中,经常会遇到需要在主线程中开启多个线程去并行执行任务,并且主线程需要等待所有子线程执行完毕后再进行汇总的场景。
再CountDownLatch 出现之前一般都使用 Thread的 join() 方法来实现这一点,但是join()方法并不灵活,不能够满足不同场景的需要,所以JDK 开发组 提供了CountDownLatch 这个类:
public static void CountDownLatchUsed() throws InterruptedException {
CountDownLatch downLatch = new CountDownLatch(6);
for (int i = 0; i < 6; i++) {
new Thread(() -> {
System.out.println(Thread.currentThread().getName() + "\t 下自习走人");
downLatch.countDown();
}, String.valueOf(i)).start();
}
downLatch.await();
System.out.println(Thread.currentThread().getName() + "自习室关门走人");
}
---
输出结果
0 下自习走人
5 下自习走人
4 下自习走人
2 下自习走人
3 下自习走人
1 下自习走人
main自习室关门走人
2. 实现原理
从CountDownLatch的名字就可以猜测其内部应该是个计数器,并且这个计数器是递减的。下面就通过源码看看JDK开发组是在 何时初始化计数器,在何时递减计数器,当计数器变为0时做了什么操作,多个线程时如何通过计时器实现同步的。为了一览 CountDownLatch 的内部结构,我们先看看类图:
从类图可以看出,CountDownLatch 是使用AQS实现的。通过下面的构造函数,你会发现,实际上是把计数器的值赋给了AQS的状态变量state,也就是这里使用AQS的状态值来表示计数器的值。
/**
* Constructs a {@code CountDownLatch} initialized with the given count.
*
* @param count the number of times {@link #countDown} must be invoked
* before threads can pass through {@link #await}
* @throws IllegalArgumentException if {@code count} is negative
*/
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
下面我们来研究 CountDownLatch 中的几个重要的方法,看它们是如何调用AQS来实现功能。
2.1 void await() 方法
#CountDownLatch/await
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
#AQS/acquireSharedInterruptibly
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 如何线程被终端,则抛出异常
if (Thread.interrupted())
throw new InterruptedException();
// 查看当前计数器的值是否为0,为0则直接返回,否则进入AQS队列等待
if (tryAcquireShared(arg) < 0)
// 加入同步队列
doAcquireSharedInterruptibly(arg);
}
#Sync类实现的AQS接口
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
由以上代码可知,该方法的特点是线程获取资源时可以被中断,并且获取的资源时共享资源。acquireSharedInterruptibly 首先判断当前线程是否已被中断,若是则抛出异常,否则调用sync实现的 tryAcquireShared 方法查看当前状态值是否为0,是则当前线程的 await() 方法直接返回,否则调用 AQS 的doAcquireSharedInterruptibly 方法让当前线程阻塞。
另外可以看到,这里tryAcquireShared 传递的arg参数没有被用到,调用tryAcquireShared的方法仅仅是为了检查当前状态值是不是为0,并没有调用CAS让当前状态值减1.
2.2 void countDown() 方法
#CountDownLatch/countDown
/**
* Decrements the count of the latch, releasing all waiting threads if
* the count reaches zero.
*
* <p>If the current count is greater than zero then it is decremented.
* If the new count is zero then all waiting threads are re-enabled for
* thread scheduling purposes.
*
* <p>If the current count equals zero then nothing happens.
*/
public void countDown() {
sync.releaseShared(1);
}
#sync/releaseShared
public final boolean releaseShared(int arg) {
// sync实现的
if (tryReleaseShared(arg)) {
// AQS释放资源的方法
doReleaseShared();
return true;
}
return false;
}
#CountDownLatch/Sync/tryReleaseShared
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
// 循环进行CAS,知道当前线程成功完成CAS使计数器的值(状态值state)减1并更新到state
for (;;) {
int c = getState();
// 如何当前状态值为0则直接返回
if (c == 0)
return false;
// 使用CAS使计数器的值减1
int nextc = c-1;
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
3. 总结
CountDownLatch 是使用 AQS 实现的。使用AQS的状态变量来存放计数器的值。首先在初始化 CountDownLatch 时设置状态值(计数器的值),当多个线程调用countdown方法时,实际是原自行递减AQS状态值。当线程调用await方法后当前线程会被放入AQS 的阻塞队列等待计数器为0再返回。
其他线程调用CountDownLatch方法让计数器的值递减1,当计数器值变为0时,当前线程还要调用AQS的 doReleaseShared 方法来激活由于调用await()方法而被阻塞的线程。