源码赏析:一文读懂java并发编程工具类CountDownLatch和CyclicBarrier的原理

150 阅读8分钟

CountDownLatch可以使一个线程等待多个线程或多个线程等待一个线程到达在某一同步点时进行同步操作。比如在读取解析大文件时,我们可以将文件进行分片,交给多个线程共同解析,每个线程只解析一个文件分片。在线程解析完毕时等待其他线程解析完毕,当所有线程都解析完毕时,主线程才进行结果合并返回给客户端。

Cyclicbarrier则是用于多个线程相互等待

CountDownLatch

我们先来看一下它的整体结构

public class CountDownLatch {
    private static final class Sync extends AbstractQueuedSynchronizer {
    }

    private final Sync sync;
    
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
    
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    public void countDown() {
        sync.releaseShared(1);
    }
}

我们看到CountDownLatch内部的同步器依然继承了AQS,其中count表示能让多少个线程进行同步。

它的两个核心方法await()countDown()await()表示当前线程等待其他线程都到达同步点,未到达同步点时,当前线程阻塞。countDown()表示某一个线程到达同步点,count值减一,count为0时表示所有线程都达到同步点。

首先看一下await()方法,内部调用了AQS的acquireSharedInterruptibly()方法。看一下AQS的此方法做了什么?

public final void acquireSharedInterruptibly(int arg)
    throws InterruptedException {
    // 判断并清除线程的中断状态,如果当前线程被中断则抛出异常
    if (Thread.interrupted() ||
        (tryAcquireShared(arg) < 0 &&  // 判断state状态,若为false则等待线程可以往下执行,不需要再进行入队操作
         acquire(null, arg, true, true, false, 0L) < 0))  //进行入队操作,若中间被中断会返回负数
        throw new InterruptedException();
}

// CountDownLatch的sync重写了AQS的此方法判断state是否为0,若为0则当前线程不在等待,接着执行
protected int tryAcquireShared(int acquires) {
    // 判断state是否为0,若为0则当前线程不在等待,接着执行
    return (getState() == 0) ? 1 : -1;
}

acquire()方法的执行过程在[上一篇文章]](ReentrantLock与AbstractQueuedSynchronizer源码解析,一文读懂底层原理)已经讲过。这是一个很重要的方法,继承了AQS的相关类都会用到此方法,可以查看上一篇文章进行了解。acquire(null, arg, true, true, false, 0L) < 0传递的参数如上,是可以响应中断的。在入队过程中被中断会跳出循环,执行 cancelAcquire(node, interrupted, interruptible)取消获取资源并返回负数。

现在我们梳理一下await()方法的流程

  1. 调用tryAcquireShared(int acquires)查看state是否为0,若为0则不需要等待,接着往下执行。
  2. 当state不为0时,接着进行入队操作。入队之后进行等待,直到state=0被唤醒,才接着往下执行

接下来我们看一下countDown()方法的底层执行逻辑。从上面我们看到countDown()调用了sync.releaseShared(1);方法。现在来看一该方法的内部情况。

// 此方法是AQS内部的一个方法
public final boolean releaseShared(int arg) {
    // 判断state减少一个值后是否为0,为0表示所有线程都已经到达同步点
    if (tryReleaseShared(arg)) {
        // 唤醒CLH队列中的线程,CLH中的线程为调用了await()方法的线程
        signalNext(head);
        return true;
    }
    // 不是则直接返回,接着执行此线程的任务
    return false;
}

// CountDownLatch内部同步器重写AQS的tryReleaseShared(int releases)
protected boolean tryReleaseShared(int releases) {
    for (;;) {
        // 获取当前state的值
        int c = getState();
        if (c == 0)
            return false;
        // 将state减小1
        int nextc = c - 1;
        // 将state原子替换为减小后的值,成功则返回
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

// AQS唤醒CLH中的线程
private static void signalNext(Node h) {
    Node s;
    // 当前节点不为null,它的下一个节点也不为null,且下一个节点在等待状态
    if (h != null && (s = h.next) != null && s.status != 0) {
        s.getAndUnsetStatus(WAITING);  // 取消等待状态
        LockSupport.unpark(s.waiter);  // 唤醒线程
    }
}

countDown()方法执行过程:

  1. 首先原子自减state(count)值
  2. 判断当前state是否为0,若为0则表示所有线程都到达同步点,让等待线程开始执行

注:无论自减之后state是否为0,此线程都不会入队,接着往下执行任务。

CountDownLatch中的count是递减的,只有减法操作没有加法,且sync被final修饰。一旦一开始给sync复制后它就不能改变了,所以count只能从初值一直减小到0,然后一直为0.所以说它是一次性的。

Cyclicbarrier

先看一下Cyclicbarrier的整体结构

public class CyclicBarrier {
    private static class Generation {
        Generation() {}                
        boolean broken;              
    }
    
    private final ReentrantLock lock = new ReentrantLock();
    
    private final Condition trip = lock.newCondition();

    private final int parties;

    private final Runnable barrierCommand;

    private Generation generation = new Generation();

    private int count;

    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }

    public CyclicBarrier(int parties) {
        this(parties, null);
    }

}

CyclicBarrier中的让线程相互等待的操作有两个方法,一个是await()一直等,还有一个await(long timeout, TimeUnit unit)超时等待。它们内部都是调用了自身的dowait()方法

public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }

    
    public int await(long timeout, TimeUnit unit)
        throws InterruptedException,
               BrokenBarrierException,
               TimeoutException {
        return dowait(true, unit.toNanos(timeout));
    }

dowait()方法是其核心方法,我们来仔细了解一下其内部原理

private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    // 获取锁来进行安全的扣减count操作
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        // 获取当前代,CyclicBarrier每重置一次成为一代
        final Generation g = generation;
        // 当前代的屏障是否损坏
        if (g.broken)
            throw new BrokenBarrierException();
        // 查看当前线程是否被中断
        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }
        // count进行扣减
        int index = --count;
        if (index == 0) {  // 若为0则表示所有线程都已经到达屏蔽点
            Runnable command = barrierCommand;  // 所有线程到达屏蔽点之后所要执行的一个任务
            if (command != null) {  // 只需执行一次,比如可以打印日志记录所有线程到达屏蔽点的时间
                try {
                    command.run();
                } catch (Throwable ex) {
                    breakBarrier(); 
                    throw ex;
                }
            }
            nextGeneration();  // 重置count值,产生下一代
            return 0;
        }

        //count不为0则进入循环,直到跳闸、断开、中断或超时
        for (;;) {
            try {  // 判断是否是超时阻塞
                if (!timed)
                    trip.await();  // 进入可中断条件等待
                else if (nanos > 0L)  // 查看是否超时
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    Thread.currentThread().interrupt();// 设置中断
                }
            }
            
            // 如果屏障被破坏
            if (g.broken)
                throw new BrokenBarrierException();
            // 如果已进入下一代
            if (g != generation)
                return index;
            // 等待超时,打破屏障,抛出异常
            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

dowait()流程为:

  1. 先判断当代屏障是否被破坏或当前线程是否被中断,被破坏或中断直接抛出异常
  2. 将count进行减一操作,并判断扣减后是否为0
  3. 若为0重置下一代,并唤醒所有在屏蔽点等待中的线程继续执行
  4. 若不为0则加入条件等待队列
  5. 若加入过程中抛出异常,则捕获异常。看yclicBarrier是否被重置,没有则破坏屏障抛出异常。否则中断当前线程,返回index。
  6. 若设置为超时等待还会判断是否超时,超时会打破屏障,抛出异常

进入trip.await()看一下它做了什么操作。

// 此方法是AQS中实现的
public final void await() throws InterruptedException {
    // 如果线程被中断,清除中断并返回true
    if (Thread.interrupted())
        // 抛出异常,在上面方法的catch捕获
        throw new InterruptedException();
    // 新建节点
    ConditionNode node = new ConditionNode();
    long savedState = enableWait(node);  // 初始化节点,设置状态并加入条件等待队列,并释放锁
    LockSupport.setCurrentBlocker(this); // 当前线程阻塞在这个对象上
    boolean interrupted = false, cancelled = false, rejected = false;
    while (!canReacquire(node)) { // 判断条件是否已满足,若满足则加入到AQS的CLH队列上
        // 如果检测到中断
        if (interrupted |= Thread.interrupted()) {
            // 检查等待节点是否已被信号通知
            if (cancelled = (node.getAndUnsetStatus(COND) & COND) != 0)
                break; // 中断发生在signal前则退出循环             
        } else if ((node.status & COND) != 0) {
            // 节点仍在条件队列中,执行阻塞
            try {
                if (rejected)
                    node.block();  // 标准阻塞,调用LockSupport.park();
                else
                    ForkJoinPool.managedBlock(node);  // 托管阻塞
            } catch (RejectedExecutionException ex) {
                rejected = true;
            } catch (InterruptedException ie) {
                interrupted = true;
            }
        } else
            Thread.onSpinWait();    // 自旋等待
    }
    // 清除阻塞对象
    LockSupport.setCurrentBlocker(null);
    // 清除状态
    node.clearStatus();
    // 加入AQS的CLH队列
    acquire(node, savedState, false, false, false, 0L);
    // 中断后处理
    if (interrupted) {
        if (cancelled) {
            unlinkCancelledWaiters(node);  // 清理被取消的节点
            throw new InterruptedException();
        }
        Thread.currentThread().interrupt(); // 中断线程
    }
}

以上就是仍然后线程未到达屏障点的执行流程。现在我们来看一下当最后一个线程到达屏蔽点是如何唤醒其他线程的。会执行nextGeneration()方法

private void nextGeneration() {
    // 唤醒条件等待队列的所有线程
    trip.signalAll();
    // 重置count和Generation
    count = parties;
    generation = new Generation();
}

// 用的是AQS中的signalAll()方法
public final void signalAll() {
    // 获取头节点
    ConditionNode first = firstWaiter;
    // 若获取锁的线程不是当前线程抛出异常
    if (!isHeldExclusively())
        throw new IllegalMonitorStateException();
    // 唤醒节点
    if (first != null)
        doSignal(first, true);
}

// 实际唤醒节点的方法
private void doSignal(ConditionNode first, boolean all) {
    // 遍历所有节点
    while (first != null) {
        ConditionNode next = first.nextWaiter;
        // 到达尾部
        if ((firstWaiter = next) == null)
            lastWaiter = null;
        // 节点还在等待中
        if ((first.getAndUnsetStatus(COND) & COND) != 0) {
            // 加入AQS的等待队列
            enqueue(first);
            if (!all)
                break;
        }
        first = next;
    }
}

dowait()方法最后会释放锁,此时唤醒刚刚加入AQS的CLH队列的线程。unlock()方法最终会调用signalNext(Node h)唤醒下一个线程并获取锁。进入条件等待时会释放锁,退出等待时会重新获取锁。只要每次释放锁都会唤醒下一个线程。这样往复唤醒所有线程。释放锁的方法也在上一篇文章讲过了。有兴趣的可以去看看。 以上就是CyclicBarrier进行同步的所有机制了。

感谢您的收看!