CyclicBarrier源码导读

611 阅读4分钟

CyclicBarrier也被称为“栅栏”、“屏障”,从名字就看的出来,它是可以被循环使用的。 CyclicBarrier的作用是当一组线程全部都到达一个状态时,再全部同时执行。

例如,在对接口做并发请求测试的时候,创建N个线程并启动,在请求接口前可能有一些准备工作要做,此时就可以调用CyclicBarrier.await(),提前准备完毕的线程会阻塞,待所有线程都准备完毕后,在一瞬间将请求全部发出。

CyclicBarrier没有直接依赖于AQS,它是基于ReentrantLock的Condition来实现的。 它内部维护了一个变量count来记录剩余等待到达的线程数,只要count大于0,线程调用await()就会被加入到Condition队列中并Park挂起。当线程全部到达,count等于0时,CyclicBarrier会调用condition.signalAll()唤醒所有的等待线程。

源码导读

CyclicBarrier的代码量并不多,主要的逻辑都在ReentrantLock和AQS里,强烈建议先看ReentrantLock和AQS的源码:《ReentrantLock源码导读》、《AQS源码导读》。

属性

CyclicBarrier内部通过ReentrantLock来保证同步,通过Condition来实现线程的阻塞和唤醒。 parties用来记录开启屏障所需的线程数量,屏障开启后,会利用parties复位。 barrierCommand用来保存屏障开启时优先执行的任务,可以为空,如果任务执行异常了,屏障会被标记为破坏的。 count用来记录屏障开启所需的剩余线程数,开启后会利用parties复位。 CyclicBarrier是可以复用的,每次屏障开启后就会更新换代一次,generation代表当前一代。

// 同步操作依赖于ReentrantLock
private final ReentrantLock lock = new ReentrantLock();
//
private final Condition trip = lock.newCondition();
// 屏障开启需要到达的线程数量
private final int parties;
// 屏障开启时,优先执行的任务
private final Runnable barrierCommand;

/**
 * 屏障开启前,还需要等待到达的线程数。
 * 每次await()就会--count,count==0时屏障开启
 */
private int count;

// 当前周期,CyclicBarrier是可以复用的,屏障开启后会调用nextGeneration()复位
private Generation generation = new Generation();

/**
 * CyclicBarrier是可以复用的,屏障开启后会调用nextGeneration()复位
 */
private static class Generation {
    /*
    表示屏障是否被破坏:
    1.等待超时了
    2.线程被中断了
    3.barrierAction执行发生异常
    4.手动reset()
     */
    boolean broken = false;
}

构造函数

CyclicBarrier提供了两个构造函数,可以指定屏障开启所需的线程数,和屏障开启时优先执行的任务。

/*
parties:屏障开启所需的线程数
barrierAction:屏障开启优先执行的任务
 */
 */
public CyclicBarrier(int parties, Runnable barrierAction) {
    if (parties <= 0) throw new IllegalArgumentException();
    this.parties = parties;
    this.count = parties;
    this.barrierCommand = barrierAction;
}

/*
仅仅指定线程数
 */
public CyclicBarrier(int parties) {
    this(parties, null);
}

核心方法

await

线程调用await()方法后,count就会自减1,如果count仍大于0,线程就会入队阻塞。 CyclicBarrier提供了两个await()方法,都是响应中断的,一个支持超时,一个不支持超时:

  • await() 不超时的等待,直到屏障开启或线程中断。
  • await(long timeout, TimeUnit unit) 超时等待,直到屏障开启、超时、或线程中断。
// 非超时的等待
public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // 不可能发生,因为这是非超时等待
    }
}

两种方法其实调用的都是dowait()dowait()是支持超时的。

  1. dowait()首先是加锁,然后判断屏障是否被破坏,如果被破坏了就不会阻塞线程,而是抛出异常。
  2. 接着判断线程是否被中断,因为CyclicBarrier是响应中断的,如果发生中断也不会阻塞,而是抛异常,并将屏障标记为破坏的,同时唤醒所有的等待线程。
  3. 以上都没发生,则count自减1,如果count大于0,则判断是否是超时等待,如果是超时等待则调用trip.awaitNanos()将线程挂起指定时间,否则trip.await()线程被无期限挂起。
  4. 如果count等于0,则说明线程全部到达,屏障开启,调用nextGeneration()唤醒所有线程,并更新换代。
/**
 * 等待屏障开启
 * @param timed 是否超时
 * @param nanos 超时时间
 */
private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
        TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock();//加锁,保证同步
    try {
        final Generation g = generation;

        if (g.broken)
            // 如果屏障被破坏,抛异常
            throw new BrokenBarrierException();

        if (Thread.interrupted()) {
            // 如果线程被中断,则屏障破坏,count复位,唤醒所有线程
            breakBarrier();
            throw new InterruptedException();
        }

        int index = --count;
        if (index == 0) {  // 线程全部到达,屏障开启
            boolean ranAction = false;// barrierCommand有没有正常执行的标记
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    // 如果构造函数指定了barrierCommand,则屏障开启时优先执行
                    command.run();
                ranAction = true;
                // 状态复位,唤醒所有的线程,开启下一轮屏障
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    // barrierCommand执行异常,屏障被破坏
                    breakBarrier();
            }
        }

        /*
        屏障尚未开启,线程阻塞,直到:
        1.屏障开启
        2.线程被中断
        3.等待超时
         */
        for (;;) {
            try {
                if (!timed)//没有开启超时,则无限期阻塞
                    trip.await();
                else if (nanos > 0L)//超时时间大于0,则LockSupport.parkNanos()
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                // 屏障被破坏
                throw new BrokenBarrierException();

            if (g != generation)
                // generation已经被复位
                return index;

            if (timed && nanos <= 0L) {
                // 等待超时了,屏障破坏,唤醒所有线程
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

reset

将屏障重置为初始状态,它会调用breakBarrier()将屏障标为被破坏的,同时唤醒所有的等待线程,后面进来的线程就会抛BrokenBarrierException异常。之后调用nextGeneration()将屏障更新换代。

/**
 * 重置屏障
 */
public void reset() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        // 当前屏障被打破,唤醒所有等待线程
        breakBarrier();
        // 开启新一轮的屏障
        nextGeneration();
    } finally {
        lock.unlock();
    }
}

getNumberWaiting

获取当前在等待屏障开启的线程数:

/**
 * 返回当前屏障的等待线程数
 */
public int getNumberWaiting() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
    	// 总数 - 剩余等待数
        return parties - count;
    } finally {
        lock.unlock();
    }
}

其他方法

CyclicBarrier的核心方法并不多,主要是await(),其他的代码都很简单:

getParties() 获取屏障开启所需的线程数:

/**
 * 返回屏障开启所需的线程数
 */
public int getParties() {
    return parties;
}

isBroken() 当前屏障是否被破坏:

/**
 * 当前屏障是否被破坏
 */
public boolean isBroken() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        return generation.broken;
    } finally {
        lock.unlock();
    }
}

总结

CyclicBarrier的代码阅读起来还是很简单的,主要依赖于ReentrantLock,ReentrantLock又依赖于AQS,因此阅读源码的时候建议从底层开始读起。 它的主要作用就是当一组线程全部都到达一个状态时,再全部同时执行,一般用来做多线程并发测试比较多。


你可能感兴趣的文章: