基于核心源码和个人思考——CyclicBarrier

284 阅读2分钟

CyclicBarrier

  • 初始化时给定计数值,并分派给线程,一开始计数为0,线程执行完后数值加1并阻塞,当计数值为指定数值时后,所有线程继续执行。
  • 实现方式是通过ReentranLock下Condition类的await方法和signalAll方法实现。

内部类Generation

表示当前屏障的状态

Generation属性

  • boolean broken = false 屏障是否打破

属性

  • ReentrantLock lock = new ReentrantLock() 可重入锁
  • Condition trip = lock.newCondition() 锁对应状态
  • int parties 屏障大小,即需要击破屏障的线程数
  • Runnable barrierCommand
  • Generation generation = new Generation()
  • int count 未到达屏障线程数量

接口

公开接口

  • int await() 计数值加1,阻塞致计数值为指定数
    • 核心是调用dowait()方法
  • int await(long timeout, TimeUnit unit) 在指定时间内阻塞
  • getNumberWaiting() 获取到达屏障的线程数
  • getParties() 获取设置释放线程的指定数
  • reset() 重置计数值为0,生成新一代屏障。
public void reset() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        breakBarrier();   // break the current generation
        nextGeneration(); // start a new generation
    } finally {
        lock.unlock();
    }
}

核心接口

  • int dowait(boolean timed, long nanos) 拦截线程,若count达到0,则全部释放,重置count,生成新的屏障状态
private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException, TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock(); // 加锁
    try {
        final Generation g = generation;
        // 屏障已经被打破,抛出异常
        if (g.broken)
            throw new BrokenBarrierException();
        
        // 线程被中断
        if (Thread.interrupted()) {
            breakBarrier(); // 屏障击破状态设为true,释放所有线程,重置count
            throw new InterruptedException();
        }

        int index = --count; // 没调用dowait(),count减一
        
        // index==0表示线程都到齐
        if (index == 0) { 
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run(); // 如果有屏障击破事件,则执行
                ranAction = true;  // 确保屏障击破事件能顺利执行,没有被异常中断
                nextGeneration(); // 唤醒阻塞的线程,重置计数,生成新一代屏障
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }
        
        // 线程没到齐
        for (;;) {
            try {
                if (!timed)
                    trip.await(); // 没有超时设置,直接阻塞当前线程
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos); // 阻塞当前线程指定时间
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
                return index;

            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock(); // 解锁
    }
}
  • void breakBarrier() 屏障击破状态设为true,唤醒阻塞的线程,重置count
private void breakBarrier() {
    generation.broken = true; // 屏障破裂状态设为true
    count = parties;  // 重置计数
    trip.signalAll();  // 唤醒所有阻塞线程
}
  • void nextGeneration() 唤醒阻塞的线程,重置计数,生成新一代屏障
private void nextGeneration() {
    trip.signalAll();  // Condition trip = lock.newCondition() 锁对应状态
    count = parties;  // 重置计数
    generation = new Generation();  // 唤醒所有阻塞线程
}