CyclicBarrier源码分析

32 阅读1分钟

属性字段

//锁
private final ReentrantLock lock = new ReentrantLock();
//条件队列
private final Condition trip = lock.newCondition();
//屏障线程数
private final int parties;
//所有线程到达屏障后执行的任务
private final Runnable barrierCommand;
//当前分代
private Generation generation = new Generation();

//每一代的线程数量,在屏障被破坏或者到下一代时候会被重置为parties
private int count;

构造函数

public CyclicBarrier(int parties) {
    this(parties, null);
}
public CyclicBarrier(int parties, Runnable barrierAction) {
    //参数校验
    if (parties <= 0) throw new IllegalArgumentException();
    //触发屏障的线程数
    this.parties = parties;
    //分代线程数
    this.count = parties;
    //回调任务
    this.barrierCommand = barrierAction;
}

核心方法

await()

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}
private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    final ReentrantLock lock = this.lock;
    //加锁
    lock.lock();
    try {
        //分代
        final Generation g = generation;
        //屏障被破坏,抛异常
        if (g.broken)
            throw new BrokenBarrierException();
        //线程被中断
        if (Thread.interrupted()) {
            //设置屏障破坏,唤醒线程
            breakBarrier();
            throw new InterruptedException();
        }
        //线程数减1
        int index = --count;
            //所有线程都到达了
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                
                final Runnable command = barrierCommand;
                if (command != null)
                    //执行回调任务
                    command.run();
                ranAction = true;
                //重置屏障
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        //不是所有线程都到达,
        for (;;) {
            try {
                //没有设置超时时间
                if (!timed)
                    //等待,直到有线程唤醒
                    trip.await();
                 //有超时时间   
                else if (nanos > 0L)
                    超时等待    
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    //有中断异常,并且是当前分代被破坏,破环屏障,唤醒线程
                    breakBarrier();
                    throw ie;
                } else {
                    //线程中断
                    Thread.currentThread().interrupt();
                }
            }
            //破环状态抛出异常
            if (g.broken)
                throw new BrokenBarrierException();
            //分代已经更新了,返回索引
            if (g != generation)
                return index;
            //超时
            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        //锁释放
        lock.unlock();
    }
}