并发编程之CountDownLatch类

925 阅读2分钟

1. 前言

在日常开发中,经常会遇到需要在主线程中开启多个线程去并行执行任务,并且主线程需要等待所有子线程执行完毕后再进行汇总的场景。

CountDownLatch 出现之前一般都使用 Thread的 join() 方法来实现这一点,但是join()方法并不灵活,不能够满足不同场景的需要,所以JDK 开发组 提供了CountDownLatch 这个类:

public static void CountDownLatchUsed() throws InterruptedException {
    CountDownLatch downLatch = new CountDownLatch(6);
    for (int i = 0; i < 6; i++) {
        new Thread(() -> {
            System.out.println(Thread.currentThread().getName() + "\t 下自习走人");
            downLatch.countDown();
        }, String.valueOf(i)).start();
    }
    downLatch.await();
    System.out.println(Thread.currentThread().getName() + "自习室关门走人");
}
---
输出结果
0	 下自习走人
5	 下自习走人
4	 下自习走人
2	 下自习走人
3	 下自习走人
1	 下自习走人
main自习室关门走人

2. 实现原理

CountDownLatch的名字就可以猜测其内部应该是个计数器,并且这个计数器是递减的。下面就通过源码看看JDK开发组是在 何时初始化计数器,在何时递减计数器,当计数器变为0时做了什么操作,多个线程时如何通过计时器实现同步的。为了一览 CountDownLatch 的内部结构,我们先看看类图:

image.png

从类图可以看出,CountDownLatch 是使用AQS实现的。通过下面的构造函数,你会发现,实际上是把计数器的值赋给了AQS的状态变量state,也就是这里使用AQS的状态值来表示计数器的值。

/**
 * Constructs a {@code CountDownLatch} initialized with the given count.
 *
 * @param count the number of times {@link #countDown} must be invoked
 *        before threads can pass through {@link #await}
 * @throws IllegalArgumentException if {@code count} is negative
 */
public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

下面我们来研究 CountDownLatch 中的几个重要的方法,看它们是如何调用AQS来实现功能。

2.1 void await() 方法

#CountDownLatch/await

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

#AQS/acquireSharedInterruptibly

public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    // 如何线程被终端,则抛出异常
    if (Thread.interrupted())
        throw new InterruptedException();
    
    // 查看当前计数器的值是否为0,为0则直接返回,否则进入AQS队列等待
    if (tryAcquireShared(arg) < 0)
        // 加入同步队列
        doAcquireSharedInterruptibly(arg);
}

#Sync类实现的AQS接口

protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

由以上代码可知,该方法的特点是线程获取资源时可以被中断,并且获取的资源时共享资源。acquireSharedInterruptibly 首先判断当前线程是否已被中断,若是则抛出异常,否则调用sync实现的 tryAcquireShared 方法查看当前状态值是否为0,是则当前线程的 await() 方法直接返回,否则调用 AQS 的doAcquireSharedInterruptibly 方法让当前线程阻塞。

另外可以看到,这里tryAcquireShared 传递的arg参数没有被用到,调用tryAcquireShared的方法仅仅是为了检查当前状态值是不是为0,并没有调用CAS让当前状态值减1.

2.2 void countDown() 方法

#CountDownLatch/countDown

/**
 * Decrements the count of the latch, releasing all waiting threads if
 * the count reaches zero.
 *
 * <p>If the current count is greater than zero then it is decremented.
 * If the new count is zero then all waiting threads are re-enabled for
 * thread scheduling purposes.
 *
 * <p>If the current count equals zero then nothing happens.
 */
public void countDown() {
    sync.releaseShared(1);
}

#sync/releaseShared

public final boolean releaseShared(int arg) {
    // sync实现的
    if (tryReleaseShared(arg)) {
        // AQS释放资源的方法
        doReleaseShared();
        return true;
    }
    return false;
}

#CountDownLatch/Sync/tryReleaseShared

protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    // 循环进行CAS,知道当前线程成功完成CAS使计数器的值(状态值state)减1并更新到state
    for (;;) {
        int c = getState();
        // 如何当前状态值为0则直接返回
        if (c == 0)
            return false;
        // 使用CAS使计数器的值减1
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

3. 总结

CountDownLatch 是使用 AQS 实现的。使用AQS的状态变量来存放计数器的值。首先在初始化 CountDownLatch 时设置状态值(计数器的值),当多个线程调用countdown方法时,实际是原自行递减AQS状态值。当线程调用await方法后当前线程会被放入AQS 的阻塞队列等待计数器为0再返回。

其他线程调用CountDownLatch方法让计数器的值递减1,当计数器值变为0时,当前线程还要调用AQS的 doReleaseShared 方法来激活由于调用await()方法而被阻塞的线程。