高并发编程-CountDownLatch深入解析

109 阅读13分钟
原文链接: mp.weixin.qq.com

若文中代码格式阅读困难,可点击文末"阅读原文"链接友好阅读。

要点解说

CountDownLatch允许一个或者多个线程一直等待,直到一组其它操作执行完成。在使用CountDownLatch时,需要指定一个整数值,此值是线程将要等待的操作数。当某个线程为了要执行这些操作而等待时,需要调用await方法。await方法让线程进入休眠状态直到所有等待的操作完成为止。当等待的某个操作执行完成,它使用countDown方法来减少CountDownLatch类的内部计数器。当内部计数器递减为0时,CountDownLatch会唤醒所有调用await方法而休眠的线程们。

实例演示

下面代码演示了CountDownLatch简单使用。演示的场景是5位运动员参加跑步比赛,发令枪打响后,5个计时器开始分别计时,直到所有运动员都到达终点。

  1. public class CountDownLatchDemo {

  2.    public static void main(String[] args) {

  3.        Timer timer = new Timer(5);

  4.        new Thread(timer).start();

  5.        for (int athleteNo = 0; athleteNo < 5; athleteNo++) {

  6.            new Thread(new Athlete(timer, "athlete" + athleteNo)).start();

  7.        }

  8.    }

  9. }

  10. class Timer implements Runnable {

  11.    CountDownLatch timerController;

  12.    public Timer(int numOfAthlete) {

  13.        this.timerController = new CountDownLatch(numOfAthlete);

  14.    }    

  15.    public void recordResult(String athleteName) {

  16.        System.out.println(athleteName + " has arrived");

  17.        timerController.countDown();

  18.        System.out.println("There are " + timerController.getCount() + " athletes did not reach the end");

  19.    }    

  20.    @Override

  21.    public void run() {

  22.        try {

  23.            System.out.println("Start...");

  24.            timerController.await();

  25.            System.out.println("All the athletes have arrived");

  26.        } catch (InterruptedException e) {

  27.            e.printStackTrace();

  28.        }

  29.    }

  30. }

  31. class Athlete implements Runnable {

  32.    Timer timer;

  33.    String athleteName;    

  34.    public Athlete(Timer timer, String athleteName) {

  35.        this.timer = timer;

  36.        this.athleteName = athleteName;

  37.    }    

  38.    @Override

  39.    public void run() {

  40.        try {

  41.            System.out.println(athleteName + " start running");

  42.            long duration = (long) (Math.random() * 10);

  43.            Thread.sleep(duration * 1000);

  44.            timer.recordResult(athleteName);

  45.        } catch (InterruptedException e) {

  46.            e.printStackTrace();

  47.        }

  48.    }

  49. }

输出结果如下所示:

  1.  Start...

  2. athlete0 start running

  3. athlete1 start running

  4. athlete2 start running

  5. athlete3 start running

  6. athlete4 start running

  7. athlete0 has arrived

  8. There are 4 athletes did not reach the end

  9. athlete3 has arrived

  10. There are 3 athletes did not reach the end

  11. athlete2 has arrived

  12. athlete1 has arrived

  13. There are 1 athletes did not reach the end

  14. There are 2 athletes did not reach the end

  15. athlete4 has arrived

  16. There are 0 athletes did not reach the end

  17. All the athletes have arrived

方法解析

1.构造方法 CountDownLatch(int count)构造一个指定计数的CountDownLatch,count为线程将要等待的操作数。

2.await() 调用await方法后,使当前线程在锁存器(内部计数器)倒计数至零之前一直等待,进入休眠状态,除非线程被中断。如果当前计数递减为零,则此方法立即返回,继续执行。

3.await(long timeout, TimeUnit unit) 调用await方法后,使当前线程在锁存器(内部计数器)倒计数至零之前一直等待,进入休眠状态,除非线程被 中断或超出了指定的等待时间。如果当前计数为零,则此方法立刻返回true值。

3.acountDown() acountDown方法递减锁存器的计数,如果计数到达零,则释放所有等待的线程。如果当前计数大于零,则将计数减少。如果新的计数为零,出于线程调度目的,将重新启用所有的等待线程。

4.getCount() 调用此方法后,返回当前计数,即还未完成的操作数,此方法通常用于调试和测试。

源码解析

进入源码分析之前先看一下CountDownLatch的类图,

Sync是CountDownLatch的一个内部类,它继承了AbstractQueuedSynchronizer。

CountDownLatch(int count)、await()和countDown()三个方法是CountDownLatch的核心方法,本篇将深入分析这三个方法的具体实现原理。

1.CountDownLatch(int count)

  1.       public CountDownLatch(int count) {

  2.        if (count < 0) throw new IllegalArgumentException("count < 0");

  3.        this.sync = new Sync(count);

  4.    }

该构造方法根据给定count参数构造一个CountDownLatch,内部创建了一个Sync实例。Sync是CountDownLatch的一个内部类,其构造方法代码如下:

  1.       Sync(int count) {

  2.        setState(count);

  3.    }

setState方法继承自AQS,给Sync实例的state属性赋值。

  1.       protected final void setState(int newState) {

  2.        state = newState;

  3.    }

这个state就是CountDownLatch的内部计数器。

2.await() 当await()方法被调用时,当前线程会阻塞,直到内部计数器的值等于零或当前线程被中断,下面深入代码分析。

  1.       public void await() throws InterruptedException {

  2.        sync.acquireSharedInterruptibly(1);

  3.    }    

  4.    public final void acquireSharedInterruptibly(int arg)

  5.            throws InterruptedException {

  6.        //如果当前线程中断,则抛出InterruptedException

  7.        if (Thread.interrupted())

  8.            throw new InterruptedException();            

  9.        //尝试获取共享锁,如果可以获取到锁直接返回;

  10.        //如果获取不到锁,执行doAcquireSharedInterruptibly

  11.        if (tryAcquireShared(arg) < 0)

  12.            doAcquireSharedInterruptibly(arg);

  13.    }    

  14.    //如果当前内部计数器等于零返回1,否则返回-1;

  15.    //内部计数器等于零表示可以获取共享锁,否则不可以;

  16.    protected int tryAcquireShared(int acquires) {

  17.        return (getState() == 0) ? 1 : -1;

  18.    }    

  19.    //返回内部计数器当前值

  20.    protected final int getState() {

  21.        return state;

  22.    }    

  23.    //该方法使当前线程一直等待,直到当前线程获取到共享锁或被中断才返回

  24.    private void doAcquireSharedInterruptibly(int arg)

  25.        throws InterruptedException {

  26.        //根据当前线程创建一个共享模式的Node节点

  27.        //并把这个节点添加到等待队列的尾部

  28.        //AQS等待队列不熟悉的可以查看AQS深入解析的内容

  29.        final Node node = addWaiter(Node.SHARED);

  30.        boolean failed = true;

  31.        try {

  32.            for (;;) {

  33.                //获取新建节点的前驱节点

  34.                final Node p = node.predecessor();

  35.                //如果前驱节点是头结点

  36.                if (p == head) {

  37.                    //尝试获取共享锁

  38.                    int r = tryAcquireShared(arg);

  39.                    //获取到共享锁

  40.                    if (r >= 0) {

  41.                        //将前驱节点从等待队列中释放

  42.                        //同时使用LockSupport.unpark方法唤醒前驱节点的后继节点中的线程

  43.                        setHeadAndPropagate(node, r);

  44.                        p.next = null; // help GC

  45.                        failed = false;

  46.                        return;

  47.                    }

  48.                }

  49.                //当前节点的前驱节点不是头结点,或不可以获取到锁

  50.                //shouldParkAfterFailedAcquire方法检查当前节点在获取锁失败后是否要被阻塞

  51.                //如果shouldParkAfterFailedAcquire方法执行结果是当前节点线程需要被阻塞,则执行parkAndCheckInterrupt方法阻塞当前线程

  52.                if (shouldParkAfterFailedAcquire(p, node) &&

  53.                    parkAndCheckInterrupt())

  54.                    throw new InterruptedException();

  55.            }

  56.        } finally {

  57.            if (failed)

  58.                cancelAcquire(node);

  59.        }

  60.    }    

  61.    private Node addWaiter(Node mode) {

  62.        //根据当前线程创建一个共享模式的Node节点

  63.        Node node = new Node(Thread.currentThread(), mode);

  64.        // Try the fast path of enq; backup to full enq on failure

  65.        Node pred = tail;

  66.        //如果尾节点不为空(等待队列不为空),则新节点的前驱节点指向这个尾节点

  67.        //同时尾节点指向新节点

  68.        if (pred != null) {

  69.            node.prev = pred;

  70.            if (compareAndSetTail(pred, node)) {

  71.                pred.next = node;

  72.                return node;

  73.            }

  74.        }

  75.        //如果尾节点为空(等待队列是空的)

  76.        //执行enq方法将节点插入到等待队列尾部

  77.        enq(node);

  78.        return node;

  79.    }    

  80.    //这里如果不熟悉的可以查看AQS深入解析的内容

  81.    Node(Thread thread, Node mode) { // Used by addWaiter

  82.        this.nextWaiter = mode;

  83.        this.thread = thread;

  84.    }    

  85.    private Node enq(final Node node) {

  86.        //使用循环插入尾节点,确保成功插入

  87.        for (;;) {

  88.            Node t = tail;

  89.            //尾节点为空(等待队列是空的)

  90.            //新建节点并设置为头结点

  91.            if (t == null) { // Must initialize

  92.                if (compareAndSetHead(new Node()))

  93.                    tail = head;

  94.            } else {

  95.                //否则,将节点插入到等待队列尾部

  96.                node.prev = t;

  97.                if (compareAndSetTail(t, node)) {

  98.                    t.next = node;

  99.                    return t;

  100.                }

  101.            }

  102.        }

  103.    }    

  104.    //获取当前节点的前驱节点

  105.    final Node predecessor() throws NullPointerException {

  106.        Node p = prev;

  107.        if (p == null)

  108.            throw new NullPointerException();

  109.        else

  110.            return p;

  111.    }    

  112.    //判断当前节点里的线程是否需要被阻塞

  113.    private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {

  114.        //前驱节点线程的状态

  115.        int ws = pred.waitStatus;

  116.        //如果前驱节点线程的状态是SIGNAL,返回true,需要阻塞线程

  117.        if (ws == Node.SIGNAL)

  118.            return true;

  119.        //如果前驱节点线程的状态是CANCELLED,则设置当前节点的前去节点为"原前驱节点的前驱节点"

  120.        //因为当前节点的前驱节点线程已经被取消了

  121.        if (ws > 0) {

  122.            do {

  123.                node.prev = pred = pred.prev;

  124.            } while (pred.waitStatus > 0);

  125.            pred.next = node;

  126.        } else {

  127.            //其它状态的都设置前驱节点为SIGNAL状态

  128.            compareAndSetWaitStatus(pred, ws, Node.SIGNAL);

  129.        }

  130.        return false;

  131.    }    

  132.    //通过使用LockSupport.park阻塞当前线程

  133.    //同时返回当前线程是否中断

  134.    private final boolean parkAndCheckInterrupt() {

  135.        LockSupport.park(this);

  136.        return Thread.interrupted();

  137.    }

3.countDown() 内部计数器减一,如果计数达到零,唤醒所有等待的线程。

  1.       public void countDown() {

  2.        sync.releaseShared(1);

  3.    }    

  4.    public final boolean releaseShared(int arg) {

  5.        //如果内部计数器状态值递减后等于零

  6.        if (tryReleaseShared(arg)) {

  7.            //唤醒等待队列节点中的线程

  8.            doReleaseShared();

  9.            return true;

  10.        }

  11.        return false;

  12.    }    

  13.    //尝试释放共享锁,即将内部计数器值减一

  14.    protected boolean tryReleaseShared(int releases) {

  15.        for (;;) {

  16.            //获取内部计数器状态值

  17.            int c = getState();

  18.            if (c == 0)

  19.                return false;

  20.            //计数器减一

  21.            int nextc = c-1;

  22.            //使用CAS修改state值

  23.            if (compareAndSetState(c, nextc))

  24.                return nextc == 0;

  25.        }

  26.    }    

  27.    private void doReleaseShared() {

  28.        for (;;) {

  29.            //从头结点开始

  30.            Node h = head;

  31.            //头结点不为空,并且不是尾节点

  32.            if (h != null && h != tail) {

  33.                int ws = h.waitStatus;

  34.                if (ws == Node.SIGNAL) {

  35.                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))

  36.                        continue;

  37.                    //唤醒阻塞的线程

  38.                    unparkSuccessor(h);

  39.                }

  40.                else if (ws == 0 &&

  41.                        !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))

  42.                    continue;

  43.            }

  44.            if (h == head)

  45.                break;

  46.        }

  47.    }    

  48.    private void unparkSuccessor(Node node) {

  49.        int ws = node.waitStatus;

  50.        if (ws < 0)

  51.            compareAndSetWaitStatus(node, ws, 0);

  52.        Node s = node.next;

  53.        if (s == null || s.waitStatus > 0) {

  54.            s = null;

  55.            for (Node t = tail; t != null && t != node; t = t.prev)

  56.                if (t.waitStatus <= 0)

  57.                    s = t;

  58.        }

  59.        if (s != null)

  60.            //通过使用LockSupport.unpark唤醒线程

  61.            LockSupport.unpark(s.thread);

  62.    }

原理总结

使用CountDownLatch(int count)构建CountDownLatch实例,将count参数赋值给内部计数器state,调用await()方法阻塞当前线程,并将当前线程封装加入到等待队列中,直到state等于零或当前线程被中断;调用countDown()方法使state值减一,如果state等于零则唤醒等待队列中的线程。

实战经验

实际工作中,CountDownLatch适用于如下使用场景: 客户端的一个同步请求查询用户的风险等级,服务端收到请求后会请求多个子系统获取数据,然后使用风险评估规则模型进行风险评估。如果使用单线程去完成这些操作,这个同步请求超时的可能性会很大,因为服务端请求多个子系统是依次排队的,请求子系统获取数据的时间是线性累加的。此时可以使用CountDownLatch,让多个线程并发请求多个子系统,当获取到多个子系统数据之后,再进行风险评估,这样请求子系统获取数据的时间就等于最耗时的那个请求的时间,可以大大减少处理时间。

面试考点

CountDownLatch和CyclicBarrier的异同?

相同点:都可以实现线程间的等待。 不同点: 1.侧重点不同,CountDownLatch一般用于一个线程等待一组其它线程;而CyclicBarrier一般是一组线程间的相互等待至某同步点; 2.CyclicBarrier的计数器是可以重用的,而CountDownLatch不可以。


 如果觉得有收获,记得关注、点赞、转发