JUC中的线程同步器原理2-CyclicBarrier

91 阅读5分钟

CountDownLatch的局限性

CountDownLatch的计数器时一次性的,也就是等到计数器值变为0后,再调用CountDownLatch的await()和countDown()都会立即返回。 为了满足计数器可以重置的需要,JDK开发组提供了CyclicBarrier类,并且CyclicBarrier类的功能不限于CountDownLatch的功能:CyclicBarrier(回环屏障)可以让一组线程全部达到一个状态后再全部同时执行,叫做回环是因为当所有等待线程执行完毕,并重置CyclicBarrier的状态后可以被重用;叫做屏障是因为线程调用await方法后就会被阻塞,这个阻塞点称为屏障点;等所有线程都调用了await()后,线程们就会冲破屏障,继续向下运行。

demo实现1

实现场景:使用两个线程去执行一个被分解的任务A,当两个线程把自己的任务都执行完毕后再对它们的结果进行汇总处理。

使用CyclicBarrier实现

public class CyclicBarrierTest1 {
    // 创建一个CyclicBarrier实例,添加一个所有子线程都到达屏障后执行的任务
    private static CyclicBarrier cyclicBarrier =
            new CyclicBarrier(2, new Runnable() {
                @Override
                public void run() {
                    System.out.println(Thread.currentThread() + "task1 merge result");
                }
            });

    public static void main(String[] args) {
        // 创建一个线程个数固定为2的线程池
        ExecutorService executorService = Executors.newFixedThreadPool(2);
        // 将线程1添加到线程池
        executorService.submit(new Runnable() {
            @Override
            public void run() {
                try {
                    System.out.println(Thread.currentThread() + "task1-1");
                    System.out.println(Thread.currentThread() + "enter in barrier");
                    cyclicBarrier.await();
                    System.out.println(Thread.currentThread() + "enter out barrier");
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        });
        // 将线程1添加到线程池
        executorService.submit(new Runnable() {
            @Override
            public void run() {
                try {
                    System.out.println(Thread.currentThread() + "task1-2");
                    System.out.println(Thread.currentThread() + "enter in barrier");
                    cyclicBarrier.await();
                    System.out.println(Thread.currentThread() + "enter out barrier");
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        });
        // 关闭线程池
        executorService.shutdown();
    }
}

使用CountDownLatch实现

public class CountDownLatchTest {
    private static volatile CountDownLatch countDownLatch =
            new CountDownLatch(2);

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executorService = Executors.newFixedThreadPool(2);
        executorService.submit(new Runnable() {
            @Override
            public void run() {
                try{
                    System.out.println(Thread.currentThread() + "Task1-1");
                    System.out.println(Thread.currentThread() + "countDown");
                    countDownLatch.countDown();
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        });
        executorService.submit(new Runnable() {
            @Override
            public void run() {
                try{
                    System.out.println(Thread.currentThread() + "Task1-2");
                    System.out.println(Thread.currentThread() + "countdown");
                    countDownLatch.countDown();
                } catch (Exception e) {
                    e.printStackTrace();
                }
            }
        });
        countDownLatch.await();
        System.out.println(Thread.currentThread()+"task1 merge result");
        executorService.shutdown();
    }
}

从这里可以看到CountDownLatch和CyclicBarrier的细微差别:

  1. 在CyclicBarrier中,子线程阻塞,直到所有子线程都调用了await()使计数器为0,此时唤醒所有子线程继续往下运行
  2. 在CountDownLatch中,主线程调用await()进入阻塞,直到所有子线程都调用了countDown()才会唤醒主线程继续往下执行
  3. 也就是说,阻塞的主体不同:CyclicBarrier中调用await()的是子线程,CountDownLatch中调用await()的是主线程,因而阻塞的主体不同;

demo实现2

实现场景:一个任务由阶段1、阶段2和阶段3组成,每个线程要串行执行阶段1、阶段2和阶段3,当多个线程执行该任务时,必须要保证所有线程的阶段1全部完成后才能进入阶段2执行,当所有线程的阶段2全部完成后才能进入阶段3执行。

public class CyclicBarrierTest2 {
    // 创建一个CyclicBarrier实例
    private static CyclicBarrier cyclicBarrier =
            new CyclicBarrier(2);

    public static void main(String[] args) {
        ExecutorService executorService = Executors.newFixedThreadPool(2);
        // 将线程1添加到线程池
        executorService.submit(new Runnable() {
            @Override
            public void run() {
                try {
                    System.out.println(Thread.currentThread() + "step1");
                    cyclicBarrier.await();
                    System.out.println(Thread.currentThread() + "step2");
                    cyclicBarrier.await();
                    System.out.println(Thread.currentThread() + "step3");
                } catch (Exception e){
                    e.printStackTrace();
                }

            }
        });
        // 将线程2添加到线程池
        executorService.submit(new Runnable() {
            @Override
            public void run() {
                try {
                    System.out.println(Thread.currentThread() + "step1");
                    cyclicBarrier.await();
                    System.out.println(Thread.currentThread() + "step2");
                    cyclicBarrier.await();
                    System.out.println(Thread.currentThread() + "step3");
                } catch (Exception e){
                    e.printStackTrace();
                }

            }
        });
        executorService.shutdown();
    }
}

image.png 每个子线程在执行完阶段1后都调用了await()方法,等到所有线程都到达屏障点后才会一起往下执行,保证了所有线程都完成了阶段1后才会开始执行阶段2;在阶段2后调用了await方法,保证了所有线程都完成了阶段2后,才能开始阶段3的执行。这个功能如果使用CountDownLatch,则需要两个CountDownLatch变量才能完成:

    ···
    System.out.println(Thread.currentThread() + "step1");
    countDownLatch1.countDown();
    System.out.println(Thread.currentThread() + "step2");
    countDownLatch2.countDown();
    System.out.println(Thread.currentThread() + "step3");
    ···

CyclicBarrier原理概述

image.png

  • ReentrantLock:实现计数器原子性更新
  • parties: 记录线程个数,表示多少线程调用await后,所有线程才会冲破屏障继续往下执行
  • count: 一开始等于parties,每当有线程调用await方法就递减1,当count为0时表示所有线程都到了屏障点
  • generation:内部有一个变量broken,用来记录当前屏障是否被打破
  • await():调用dowait()
  • dowait():
// Main barrier code, covering the various policies.
private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        final Generation g = generation;

        if (g.broken)
            throw new BrokenBarrierException();

        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }

        int index = --count;
        // 如果index==0,说明所有线程都到了屏障点,此时执行初始化时传递的任务
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                // 执行任务
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                // 激活其他因调用await()方法而被阻塞的线程,并重置CyclicBarrier
                nextGeneration();
                // 返回
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        // 如果index!=0
        for (;;) {
            try {
                // 没有设置超时时间
                if (!timed)
                    // 加入trip的条件变量队列
                    trip.await();
                // 设置了超时时间
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
                return index;

            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        // 释放锁
        lock.unlock();
    }
}


// Updates state on barrier trip and wakes up everyone. Called only while holding lock.
private void nextGeneration() {
    // signal completion of last generation
    // 唤醒条件队列里面的阻塞线程
    trip.signalAll();
    // set up next generation
    // 重置CyclicBarrier
    count = parties;
    generation = new Generation();
}

  1. 首先获取独占锁lock,假设创建CyclicBarrier时传递的参数是5,那么后面4个调用线程会被阻塞;
  2. 然后当前获取到锁的线程会对计数器count进行递减操作,递减后count = 4;
  3. 不为0则当前线程放入条件变量trip的条件阻塞队列,当前线程会被挂起并释放获取的lock锁;
  4. 当第一个获取锁的线程由于被阻塞释放锁后,被阻塞的4个线程中有一个会竞争到lock锁,执行相同操作,直到最后一个线程获取到lock锁,此时已经有4个线程被放入了条件变量trip的条件队列里面;
  5. 最后count=0,如果创建CyclicBarrier时传递了任务,则在其他线程被唤醒前先执行任务,任务执行完毕后在唤醒其他4个线程,并重置CyclicBarrier,这5个线程就可以继续往下执行。 Question: CyclicBarrier是如何实现复用的?

parties始终用来记录总的线程个数,当count计数器值变为0后,会将parties的值赋给count,从而进行复用