不会还有人搞不懂CountDownLatch和CyclicBarrier吧?

323 阅读17分钟

不会还有人搞不懂CountDownLatch和CyclicBarrier吧?

在多线程开发过程中,线程之间的通信协调通常是程序员开发过程中比较头疼的问题,JDK给我们提供了两个非常方便简单的同步协调工具,就是CountDownLatch和CyclicBarrier,但是很多人对这两个工具还是一些误解。本文我们从源码的层面,来对这两个工具类内部实现进行剖析,最后来分析一下这两个工具类的使用场景都有哪些。

CountDownLatch

Introduction

CountDownLatch(同步计数器)作为JUC包中最常用的工具,主要的同步目的是使一个或多个线程可以阻塞等待其他作业线程执行完毕。

CountDownLatch的使用逻辑非常简单。通过给定的count值初始化一个计数器,该计数器作为门闩存在。等待线程可以调用awati()方法进行等待,直到count削减到0,才能往下执行。其他作业线程则通过countDown()方法进行削减,一般在线程任务完成的时候进行削减,推荐至少通过try-finally方式进行,避免线程执行异常未成功削减,导致等待线程无限挂起(等待可以设置超时)。

除了在作业线程执行完毕后调用其持有的同步计数器外,我们也可以通过在作业线程执行开始前,通过传入的另一个同步计数器来保证所有线程同步进行。以下是JDK源码中的Demo:

class Driver { // ...
    void main() throws InterruptedException {
        CountDownLatch startSignal = new CountDownLatch(1);
        CountDownLatch doneSignal = new CountDownLatch(N);

        for (int i = 0; i < N; ++i) // create and start threads
            new Thread(new Worker(startSignal, doneSignal)).start();

        doSomethingElse();            // don't let run yet
        startSignal.countDown();      // let all threads proceed
        doSomethingElse();
        doneSignal.await();           // wait for all to finish
    }
}

class Worker implements Runnable {
    private final CountDownLatch startSignal;
    private final CountDownLatch doneSignal;

    Worker(CountDownLatch startSignal, CountDownLatch doneSignal) {
        this.startSignal = startSignal;
        this.doneSignal = doneSignal;
    }

    public void run() {
        try {
            startSignal.await();
            doWork();
            doneSignal.countDown();
        } catch (InterruptedException ex) {
        } // return;
    }

    void doWork() {...}
}

但一般情况下我们最常使用的场景仍然是将工作拆分成子任务,提交到线程池中执行,并在主线程中调用await()等待作业执行完毕。

class Driver2 { // ...
    void main() throws InterruptedException {
        CountDownLatch doneSignal = new CountDownLatch(N);
        Executor e = ...

        for (int i = 0; i < N; ++i) // create and start threads
            e.execute(new WorkerRunnable(doneSignal, i));

        doneSignal.await();           // wait for all to finish
    }
}

class WorkerRunnable implements Runnable {
    private final CountDownLatch doneSignal;
    private final int i;

    WorkerRunnable(CountDownLatch doneSignal, int i) {
        this.doneSignal = doneSignal;
        this.i = i;
    }

    public void run() {
        try {
            doWork(i);
            doneSignal.countDown();
        } catch (InterruptedException ex) {
        } // return;
    }

    void doWork() { ...}
}

Structure

首先我们看一下源码结构,CountDownLatch的源码非常简单,加上注释总共才300+行,因此我们不用觉得这是一个很复杂的东西,说不定让你自行实现也能成功搞出来。

我们先来看一下它为数不多的几个方法和结构。

image-20200829114950503
image-20200829114950503

相信大家也是一眼就能看出来它的实现方式了。没错,和大多数JUC包的工具一样,CountDownLatch也是通过AQS抽象队列同步器实现的。

image-20200829115252559
image-20200829115252559

Sync子类作为CountDownLatch的同步控制器存在,并且使用AQS的state变量来表示count值(即创建CountDownLatch时传递的数量)。

Construction

我们先来看一下它的构造器。CountDownLatch只有一个构造器,允许传值一个合法(大于0)的count值。

image-20200829115454669
image-20200829115454669

对于子类Sync来说,就是将AQS的state状态设置为count值。对于CountDownLatch查看当前计数值的方法,即getCount(),也是直接过去AQS的state状态。

image-20200829115555661
image-20200829115555661

countDown

我们先从计数削减开始学习。

image-20200829120630345
image-20200829120630345

该方法作用即削减1个计数,当计数达到0时,释放(唤醒)所有等待线程。如果当前的state本来就是0,则什么都不发生。

同样的,我们进入AQS的releaseShared()方法看看计数削减的逻辑。

image-20200829121052506
image-20200829121052506

对于CountDownLatch的逻辑使用场景来说,同时会有多个线程对该计数进行削减,那么AQS肯定是通过共享模式实现,而不是独占模式。首先我们知道,doReleaseShared()方法即使通过共享模式来释放(唤醒)当前阻塞队列中的线程,那么什么时候会进行这个唤醒操作呢,即尝试以共享模式释放state返回true时(对于CountDownLatch的实现逻辑来说,即是计数归0时)。我们看一下AQS由子类Sync实现的削减计数方法tryReleaseShared(int)。

image-20200829231414481
image-20200829231414481

首先开启循环体(自旋CAS),尝试获取当前的state计数。如果c==0,表示什么都不发生,这与countDown()方法注释对应;否则尝试CAS进行state的替换削减操作。这个方法有个比较搞笑的地方,就是releases变量没有被使用,因为计数总是被单个单个地削减。

await()

理完countDown()方法的思路之后,我们来看一下await()。和countDown()方法不一样的顺序,我们先从AQS子类Sync实现的tryAcquiredShared()方法看。

image-20200829232726430
image-20200829232726430

可以看到,返回状态有两种。如果当前计数为0,则返回1,否则返回-1。我们只需要先记住这个方法的返回值,然后来看await()方法。

image-20200829232824105
image-20200829232824105

我们看到await()方法注释比较多,大概的意思就是描述该方法会使当前线程等待直到计数归零,除非线程被中断,也即是只有计数归零和线程中断两种情况会导致该方法从休眠状态中恢复。

await()方法调用的是AQS的acquireSharedInterruptibly()方法。

image-20200829233053939
image-20200829233053939

在这里会判断一次线程状态,是否抛出InterruptedException。并且我们看到,这里使用了tryAcquireShared()模板方法,也就是我们刚刚首先看的方法。在返回值小于0,即当前计数器不为0时,真正调用doAcquireSharedInterruptibly()方法;否则直接返回。

我们稍微看一下doAcquireSharedInterruptibly()方法(具体的放在以后AQS学习中说)。

image-20200829233849809
image-20200829233849809

首先是向队列中以共享模式新增了一个节点,并且开启循环体。通过判断当前节点的前驱节点是否为头节点(即是否),如果是头节点,则说明没有等待节点,再次尝试获取资源,如果返回>=0(返回1,即当前计数已归零),则进行头节点设置,并向链表传播该事件。

带超时的await()方法和await()差不多,不再介绍。

CyclicBarrier

Introduction

CyclicBarrier直译大概叫循环壁垒,习惯还是称之为栅栏。其主要功能是使一系列线程可以通过在同一个栅栏处停下等待(直到所有线程都达到栅栏),并且其实可循环使用的,即不需要重新创建新的栅栏,在触发一次完整的栅栏之后,栅栏的计数会恢复。

源码注释中有一个demo。

class Solver {
    final int N;
    final float[][] data;
    final CyclicBarrier barrier;

    class Worker implements Runnable {
        int myRow;

        Worker(int row) {
            myRow = row;
        }

        public void run() {
            while (!done()) {
                processRow(myRow);

                try {
                    barrier.await();
                } catch (InterruptedException ex) {
                    return;
                } catch (BrokenBarrierException ex) {
                    return;
                }
            }
        }
    }

    public Solver(float[][] matrix) {
        data = matrix;
        N = matrix.length;
        Runnable barrierAction =
                new Runnable() {
                    public void run() {
                        mergeRows(...);
                    }
                };
        barrier = new CyclicBarrier(N, barrierAction);

        List<Thread> threads = new ArrayList<Thread>(N);
        for (int i = 0; i < N; i++) {
            Thread thread = new Thread(new Worker(i));
            threads.add(thread);
            thread.start();
        }

        // wait until done
        for (Thread thread : threads)
            thread.join();
    }
}

该demo主要介绍了CyclicBarrier的其中一种使用方法,即通过构造器传入计数值外,还传入了一个Runnable,该Runnable的定义为在所有线程都达到栅栏时,调用该Runnable。在以上的例子中,即是每个工作线程生成矩阵的一行,并且在栅栏处阻塞等待,当所有工作线程都阻塞在栅栏时,调用传入的barrierAction,使矩阵合并。

CyclicBarrier使用了all-or-none breakage model中断模型,即一个线程因为中断、失败或超时等其他原因中断,则所有阻塞在该栅栏的线程都会通过BrokenBarrierException或InterruptedException进行异常退出。

Structure

CyclicBarrier的代码也非常的简单,加上注释总共不到500行。我们同样也是先从这个类的整体结构开始看。

image-20200831223431030
image-20200831223431030

可以看到,这个类有两个构造器。几个公共的方法放在后面来,我们先来看一下几个变量。

image-20200831224010151
image-20200831224010151

首先是最下面的count变量,可以看到该变量标识当前有多少参与者(线程)正在等待,和parties变量,表示该栅栏限制的参与者数量。

接着是ReentrantLock和Condition,lock作为内部变量来保证同一时间只有单个线程可以对count变量进行操作。

再是Generation,该类是CyclicBarrier的一个内部类,用来标识栅栏翻新的代数,其内部只有一个标记该栅栏是否损坏的布尔值变量(CyclicBarrier通过该状态来实现之前提到的all-or-none breakage model中断模型)。

最后是一个Runnable,保存栅栏清空时执行的操作。

Construction

同样先来看一下其构造函数。

image-20200903230711818
image-20200903230711818

构造函数并没有特殊的处理,只是初始化必须的变量,区别在于是否指定栅栏清空时的操作。

await()

CyclicBarrier只有一个比较关键的方法,就是await()。

image-20200903230949801
image-20200903230949801

方法注释较多,我这边不把所有的注释截图,大概介绍一下。

该方法主要功能就是阻塞直到该所有parties都调用了await()方法;在没有所有线程都达到栅栏处时,线程会一直保持休眠状态,直到1. 所有线程都到达栅栏、2. 其他线程中断当前线程、3. 其他线程中断其中一个等待线程(all-or-none breakage model)、4. 其他线程调用该栅栏的reset()方法;任何等待线程被中断,都会导致其他线程抛出BrokenBarrierException异常,并将该栅栏置为broken损坏状态;如果当前线程是最后一个到达栅栏处的线程,则在唤醒其他线程之前,会执行barrierCommand。

可以看到,await()方法实际调用的是dowait()方法。该方法较长,我们可以分两段来看。

image-20200903232726174
image-20200903232726174

首先通过lock变量获取到锁,在实际操作栅栏变量之前,对栅栏状态和线程状态进行检查。接下来就是很关键的count变量的削减逻辑,因为是在加锁代码块,不会出现并发影响,直接对count进行递减判断是否是最后一个线程,如果是最后一个线程,则进行barrierCommand的调用,开启新一代的栅栏。

image-20200903233131228
image-20200903233131228

接下来是第二段,这一段属于所有等待线程的通用逻辑。开启一个循环体,使等待线程在遇到栅栏清空、损坏、线程中断、超时之前,一致进行循环等待。实际在这一块的逻辑只有try{}代码块中的4行,调用trip Condition的await()方法等待,其他的都是对当前栅栏状态和线程状态的检查。

dowait()方法大概的逻辑走下来之后非常简单,但是在其中我们可以看到有两个比较关键的方法,开启下一代栅栏的nextGeneration()和打破栅栏的breakBarrier()。前面我们提到等待线程会通过trip Condition的await()方法进行等待,而在这两个方法中,就是将等待线程唤醒的逻辑。

我们先看正常唤醒的nextGeneration()方法。

image-20200903233703740
image-20200903233703740

代码非常简单,唤醒trip Condition的所有等待线程,并将栅栏初始化开启下一代。此时所有被阻塞在栅栏处的线程继续往下执行。

再来看一下await()逻辑中多次出现的breakBarrier()方法。

image-20200903233859627
image-20200903233859627

同样非常简单,将当前栅栏状态置为broken破损状态,重置栅栏并唤醒其他等待线程。但这个方法和nextGenration()方法有些不同,线程的唤醒操作在最后一步。我们回到dowait()方法看看。

可以看到,在breakBarrier()方法之后,都会对当前栅栏的broken破损状态进行检查,在适当的时候抛出BrokenBarrierException异常。也就是说,generation成员的broken变量必须在唤醒所有等待线程之前修改,等待线程才能正确的识别当前栅栏的状态。

以上就是CyclicBarrier的await()方法,带超时判断的await(long, TimeUnit)实际也是调用的dowait()方法。

image-20200903234429907
image-20200903234429907

CyclicBarrier的其他方法为栅栏状态的查询和重置方法,都较为简单,就不介绍了,大家可以花两分钟时间看看。

Conclusion

在阅读本文之前,估计很多人提起这两个线程协调类,都很难讲出他们两个的区别和分别的应用场景,但根据源码阅读一遍之后,就能够比较清晰的对这两个类有比较深刻的认识。

首先从实现上来说,CountDownLatch是通过AQS抽象队列同步器实现,使用CAS对计数值进行削减,所有线程(一般情况下不会有非常多线程等待)通过阻塞队列的形式保存,在计数值到达0的时候进行队列的传播唤醒。而CyclicBarrier主要是通过互斥锁ReentrantLock和Condition实现,其没有一个显式提供的削减计数方法(例如CountDownLatch的countDown()方法),每个线程调用await()时都会通过Condition#await()进入等待状态,直到栅栏清空或异常(栅栏损坏或线程中断)才会唤醒Condition阻塞的线程。

其实从上面的实现上也能明显的看出,CountDownLatch的countDown()和await()方法,提供了一套非常方便的主-从线程协调机制,只有在从线程通过countDown()方法削减计数值之后,主线程才能跨过await()方法的阻塞,这种机制天然地适合任务拆分的场景,即只有从线程全部完成之后,主线程才能继续往下走。而CyclicBarrier却不同,它的await()方法没有非常明显的主-从关系,所有线程都属于同等地位,可以认为是使所有线程保持公平的一个栅栏锁,在有任何一个线程未就位时,所有线程都必须等待而不能提前往下执行,适合多线程(例如上面提到的任务拆分从线程)之间的统一协调,例如测试并发请求,等待所有线程同时执行请求命令。

这样一来,大家应该就能在合适的多线程协调场景使用合适的同步工具了吧~