简单理解 CountDownLatch

128 阅读4分钟

CountDownLatch通过AQS的state控制多线程之间的协助

当state为0的时候,await()阻塞的线程被唤醒 countDown的作用是为了使得AQS的state - 1

demo:

package com.xxx.service.lock;

import java.util.concurrent.CountDownLatch;

/**
 * Created by IntelliJ IDEA.
 * Date: 2020-05-27
 * Time: 11:14
 */
public class CountDownLatchDemo {

    public static void main(String [] args) {
        CountDownLatch start = new CountDownLatch(1);
        CountDownLatch runnerFinish = new CountDownLatch(2);

        for (int i = 0; i < 2; i++) {
            Runner runner = new Runner(start, runnerFinish);
            Thread thread = new Thread(runner);
            thread.start();
        }
        try {
            // 等待所有的thread线程都执行到start.await
            Thread.sleep(1000);
            // 唤醒被start.await阻塞的thread子线程,让所有thread线程都执行继续的后续的逻辑
            start.countDown();
            // 阻塞主线程,等待thread线程全部执行完countDown,然后唤醒主线程,结束流程
            runnerFinish.await();
            System.out.println("结束");

        } catch (Exception e) {

        }

    }

    static class Runner implements Runnable {

        // 游戏开始
        private CountDownLatch start;

        // 选手完成
        private CountDownLatch runnerFinish;

        Runner (CountDownLatch start, CountDownLatch runnerFinish) {
            this.start = start;
            this.runnerFinish = runnerFinish;
        }

        @Override
        public void run() {

            try {
                System.out.println("线程:" + Thread.currentThread().getName() + "预备");
                // 阻塞,当state=0后,被唤醒
                start.await();
                System.out.println("线程:" + Thread.currentThread().getName() + "开始跑");
                // 扣减
                runnerFinish.countDown();
            } catch (Exception e) {

            }
        }
    }
}



源码分析:
CountDownLatch提供的
public void countDown() {
    sync.releaseShared(1);
}


调用的是AQS提供的共享锁释放
public final boolean releaseShared(int arg) {
    // tryReleaseShared 是模板方法,需要子类实现
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}


获取锁的逻辑由子类实现:tryReleaseShared(arg)
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    // 自旋
    for (;;) {
        // 获取当前state的值
        int c = getState();
        // 如果已经为0了,那么不能再减1
        if (c == 0)
            return false;
        int nextc = c-1;
        // 通过CAS设置新的state值,并且判断是否为0
        // 为0表示对阻塞队列的线程唤醒
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}


阻塞队列的唤醒由AQS提供
private void doReleaseShared() {
    // 自旋唤醒所有的有效节点(头结点不是,头结点是虚节点)
    for (;;) {
        Node h = head;
        // 判断头结点是否为空,并且头结点指针和尾结点指针不是同一个
        // 如果此时没有线程调用过CountDownLatch的await方法的话,队列是空的
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            // 头结点的状态是否符合被唤醒 -1
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                // 唤醒队列的节点(只是唤醒一个)
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}


CountDownLatch的await方法,阻塞队列,等待countDown()方法的唤醒。
由子类实现
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

内部acquireSharedInterruptibly(1)方法由AQS提供,可中断共享锁获取。
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
        // tryAcquireShared是一个模板方法,需要子类实现
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}


tryAcquireShared方法由子类实现。
如果AQS的state还没有为0,则表示,线程需要被阻塞。
protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}


doAcquireSharedInterruptibly(arg)方法由AQS提供实现,可中断获取共享锁。
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    // 新增共享节点
    // addWaiter内部如果同步队列是空的 会创建一个虚节点作为头结点 new Node()
    // 在这周情况下node的predecessor就是head
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        // 自旋
        for (;;) {
            // 获取头结点 
            final Node p = node.predecessor();
            if (p == head) {
                // 尝试再次获取一次锁
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            // shouldParkAfterFailedAcquire 内部逻辑是对于同步队列中的waitStatus状态的判断,已经设置为-1
            if (shouldParkAfterFailedAcquire(p, node) &&
                // parkAndCheckInterrupt 内部逻辑是对线程进行阻塞,LockSupport.park(this); 
                // 并且判断线程是否被中断,返回中断标志Thread.interrupted();改方法会重置线程的中断状态
                parkAndCheckInterrupt())
                // 如果线程被中断了,直接抛出异常
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            // 抛出异常后的线程节点,需要移除出同步队列
            cancelAcquire(node);
    }
}

总结一下就是:

CountDownLatch通过初始化一个AQS的state

假设线程A调用了await(),线程A将会被阻塞

那么线程A需要等待其他的线程调用countDown,将state减为0,然后线程A会被唤醒执行逻辑。

举个栗子: (具体可以看上面的demo)
5个人一起在起跑线等待裁判的起跑信号。

5人为CountDownLatch runner = new CountDownLatch(5)

信号 CountDownLatch start = new CountDownLatch(1)

首先信号await(),等待各个5个线程预备好。

然后信号countDown,开始起跑。

紧接着runner.await()等待runner

(执行5次 countDown)全部跑到终点

runner.await()被唤醒,比赛结束。