一、实现原理
CountDownLatch 基于 AQS(AbstractQueuedSynchronizer,抽象队列同步器)+ CAS 实现。构造时指定计数值 count,内部类 Sync 将其赋值给 AQS 的 state。
countDown()通过 CAS 原子将 state 减 1,state 为 0 时唤醒等待队列中的线程。await()判断 state:大于 0 则阻塞当前线程并加入等待队列;state 为 0 时直接放行。
⚠️注意:
✔️CountDownLatch 有同步队列优化(基于 AQS)
✔️但只有调用 await() 的线程才会进入同步队列,实际阻塞的是主线程
✔️调用 countDown() 的线程不会进入同步队列,它们只是 CAS 修改 state
二,使用场景
CountDownLatch 内部维护了一个计数器 **(AbstractQueuedSynchronizer的state属性值) **,该计数器初始值为 N,代表需要等待的线程数目,当一个线程完成了需要等待的任务后,就会调用 countDown() 方法将计数器减 1,当计数器的值为 0 时,等待的线程就会开始执行。
1、汇总场景
package com.nl;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ThreadLocalRandom;
public class CountDownLatchDemo {
public static void main(String[] args) throws Exception {
CountDownLatch countDownLatch = new CountDownLatch(3);
for (int i = 0; i < 5; i++) {
new Thread(() -> {
try {
Thread.sleep(1000 + ThreadLocalRandom.current().nextInt(2000));
System.out.println(Thread.currentThread().getName() + ":任务执行完成,当前计数器值为:" + countDownLatch.getCount());
countDownLatch.countDown();
} catch (InterruptedException e) {
e.printStackTrace();
}
},"线程" + i).start();
}
// 主线程在阻塞,当计数器为0,就唤醒主线程往下执行
countDownLatch.await();
System.out.println(Thread.currentThread().getName() + ":所有任务运行完成,结果汇总");
}
}
2、结果
线程0:任务执行完成,当前计数器值为:3
线程4:任务执行完成,当前计数器值为:2
线程1:任务执行完成,当前计数器值为:1
main:所有任务运行完成,结果汇总
线程2:任务执行完成,当前计数器值为:0
线程3:任务执行完成,当前计数器值为:0
⚠️注意: ✔️
✔️线程调用countDown()会通过 CAS 原子减少 state,state 减到 0 时唤醒 await 阻塞的线程;即使调用countDown()的次数超过初始值,也不影响其他线程,且不会再修改 state。
三、构造函数
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
Sync(int count) {
setState(count);
}
⚠️注意:
✔️**信号量赋值到AbstractQueuedSynchronizer的state属性值, **当计数器的值为 0 时,调用await的等待线程就会开始执行
四 、 常用方法
1、await
AbstractQueuedSynchronizer的state属性不为0,阻塞
// 调用 await() 方法的线程会被挂起,它会等待直到 count 值为 0 才继续执行
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
// 和 await() 类似,若等待 timeout 时长后,count 值还是没有变为 0,不再等待,继续执行
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
// 获取共享锁,并允许其中断
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
// 获取锁,由CountDownLatch实现,
// state > 0:有线程在持有锁资源,将当前线程添加到AQS等待队列
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
protected int tryAcquireShared(int acquires) {
// 线程全部执行完成,返回 1;未全部执行完成,返回-1
return (getState() == 0) ? 1 : -1;
}
// 当前线程加入等待队列并阻塞,直到获取锁或被中断
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
//是否轮到当前节点获取锁
if (p == head) {
// 调用 CountDownLatch 实现的方法
int r = tryAcquireShared(arg);
// 返回值为1,表示 state 为 0 ,线程都释放了锁,无其他线程持有锁
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
// shouldParkAfterFailedAcquire 判断是否需要将当前线程挂起
// parkAndCheckInterrupt注释当前线程,同时返回当前线程中断状态,如果true
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
⚠️注意:
✔️ 获取共享锁,并允许线程中断 **(Thread.interrupted()), **线程中断,抛出异常
✔️调用await()的线程会进入 AQS 阻塞队列,只要 state > 0 就持续阻塞,直到countDown()将 state 减为 0 才被唤醒。
✔️即使执行线程数大于 state 初始值,也只会将 state 减到 0,之后countDown()不再修改,await()线程正常唤醒。
2、countDown
计数器减一
public void countDown() {
sync.releaseShared(1);
}
public final boolean releaseShared(int arg) {
// 尝试释放锁
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
protected boolean tryReleaseShared(int releases) {
for (;;) {
//获取当前持有锁资源的线程数
int c = getState();
// state已为0,返回false
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
// 没有线程持有锁,返回true
return nextc == 0;
}
}
private void doReleaseShared() {
for (;;) {
Node h = head;
// 判断头节点 h 是否不为空且不是尾节点 tail
if (h != null && h != tail) {
int ws = h.waitStatus;
// 如果头节点的等待状态为 Node.SIGNAL
if (ws == Node.SIGNAL) {
//使用 CAS 将其状态从 Node.SIGNAL 设置为 0,如果 CAS 失败,则继续循环重新检查。
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue;
//成功后调用 unparkSuccessor(h) 唤醒后继节点。
unparkSuccessor(h);
}
//如果头节点的等待状态为 0:
//使用 CAS 将其状态设置为 Node.PROPAGATE。
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue;
}
//如果头节点未变,则跳出循环。
//如果头节点已变,说明有其他线程修改了队列结构,需继续循环处理
if (h == head)
break;
}
}
⚠️注意:
✔️执行线程数大于 state 初始值,也只会将 state 减到 0,之后 countDown() 不再修改,await() 线程正常唤醒。
五、总结
CountDownLatch 基于 AQS+CAS 实现:
- 构造时指定 count,通过内部 Sync 类将 count 赋值给 AQS 的 state;
- countDown () 本质是 CAS 将 state 原子减 1,state 归 0 时唤醒等待线程;
- await () 判断 state 是否为 0,state>0 则阻塞线程,state=0 时结束阻塞。