CyclicBarrier源码解析

394 阅读3分钟

上一期介绍了CountDownLatch的源码解析,CountDownLatch有几个问题:首先CountDownLatch在await之后必须依靠别的线程来给它countDown,打开门闩;其次CountDownLatch在countDown到0之后,该CountDownLatch的生命周期就结束了,它不能重用。那么有没有既可以自己给自己打开门闩而且还能重用的呢,有的,那就是CyclicBarrier,译作回环栅栏。它的使用方法和CountDownLatch差不多,也有一个计数值,叫做parties。CyclicBarrier的使用通俗来说就是,有一个栅栏(CyclicBarrier),它必须由n(parties)个人才能被推到,推到之后这n个人(线程)才能出来,出来之后呢,再把这个栅栏重新立起来又可以用了(重用)。下面从源码的角度分析一下CyclicBarrier的实现原理。

属性

  • lock(ReentrantLock)

    它是用来给CyclicBarrier的操作加锁的

  • trip(Condition)

    用来实现CyclicBarrier的wait和notify的

  • parties(int)

    计数值,相当于CountDownLatch的count

  • barrierCommand(Runnable)

    当CyclicBarrier打开后,要执行的任务

  • generation(Generation)

    CyclicBarrier的代,用来实现CyclicBarrier的重用

  • count(int)

    值等于parties,每次有线程进入时,count值减一,减到0的时候CyclicBarrier打开,随后count值被reset为parties

内部类

静态内部类Generation,该类的对象表示CyclicBarrier的当前代,Generation类有一个属性broken,用来表示当前屏障是否被损坏。

private static class Generation {
    boolean broken = false;
}

构造方法

parties表示在CyclicBarrier被打开之前,需要有parties个线程执行await方法。

barrierAction表示CyclicBarrier被打开的时候需要执行的command,可以为null

public CyclicBarrier(int parties) {
    this(parties, null);
}

public CyclicBarrier(int parties, Runnable barrierAction) {
    if (parties <= 0) throw new IllegalArgumentException();
    this.parties = parties;
    this.count = parties;
    this.barrierCommand = barrierAction;
}

方法

await

await方法调用了私有的dowait方法。

public int await() throws InterruptedException, BrokenBarrierException {
    try {
      	return dowait(false, 0L);
    } catch (TimeoutException toe) {
      	throw new Error(toe); // cannot happen
    }
}
private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException, TimeoutException {
    final ReentrantLock lock = this.lock;
    // 首先加锁
    lock.lock();
    try {
        final Generation g = generation;
        //判断当前generation是不是broken的
        if (g.broken)
            throw new BrokenBarrierException();
        // 线程被interrupt,则设置当前generation为broken的,唤醒all
        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }
        //每进来一个线程,count值减1
        int index = --count;
        if (index == 0) {  // count减到0,也就是CyclicBarrier要打开了
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                //command不为null,执行command
                if (command != null)
                    command.run();
                ranAction = true;
                //更新generation,以便下次重用
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // count值不为0的时候
        for (;;) {
            try {
                //由于timed=false,所以进入await,线程挂起
                if (!timed)
                    trip.await();
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }
            if (g.broken)
                throw new BrokenBarrierException();
            //如果线程因为换代操作而被唤醒,则返回计数器的值
            if (g != generation)
                return index;
            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}
//切换CyclicBarrier到下一代
private void nextGeneration() {
    //唤醒条件队列所有线程
    trip.signalAll();
    //设置count值为parties
    count = parties;
    //设置CyclicBarrier的generation为新的generation
    generation = new Generation();
}

//破坏当前栅栏
private void breakBarrier() {
    //将当前CyclicBarrier的broken状态设置为true
    generation.broken = true;
    //设置count值为parties
    count = parties;
    //唤醒所有线程
    trip.signalAll();
}

例子

CyclicBarrier源码中提供的使用示例

class Solver {
    final int N;
    final float[][] data;
    final CyclicBarrier barrier;
 
    class Worker implements Runnable {
      int myRow;
      Worker(int row) { myRow = row; }
      public void run() {
        while (!done()) {
          processRow(myRow);
 
          try {
            barrier.await();
          } catch (InterruptedException ex) {
            return;
          } catch (BrokenBarrierException ex) {
            return;
          }
        }
      }
    }
 
    public Solver(float[][] matrix) {
      data = matrix;
      N = matrix.length;
      Runnable barrierAction =
        new Runnable() { public void run() { mergeRows(...); }};
      barrier = new CyclicBarrier(N, barrierAction);
 
      List<Thread> threads = new ArrayList<Thread>(N);
      for (int i = 0; i < N; i++) {
        Thread thread = new Thread(new Worker(i));
        threads.add(thread);
        thread.start();
      }
 
      // wait until done
      for (Thread thread : threads)
        thread.join();
    }
}