JUC并发编程04——CountDownLatch源码分析

129 阅读5分钟

nick-fewings-DdIGZjmQzdA-unsplash.jpg

CountDownLatch是一个同步工具类,在JDK1.5被引入的,跟它一起被引入的同步工具类还有CyclicBarrierSemaphore等,它们都存在于java.util.concurrent包下。CountDownLatch的功能是能够使一个线程(或多个线程)等待其他线程完成各自的任务后再执行。例如,主线程等待多个子线程全部执行完后主线程再继续执行。

CountDownLatch是通过一个计数器来实现的,计数器的初始值为线程的数量。每当一个线程完成了自己的任务后,计数器的值就会减1。当计数器值到达0时,它表示所有的线程已经完成了任务,然后在闭锁上等待的线程就可以恢复执行任务。

说明:源码析之前要先具备以下两个前置知识点:

  1. LockSupport的用法和原理
  2. AQS的原理

接下来就通过一个demo来探究CountDownLatch的内部原理

其实这些同步工具类底层都是AQS,掌握了AQS再来看这些就相对简单了

/**
 * @author qiuguan
 */
public class CountDownLatchDemo {

    public static void main(String[] args) {

        /**
         * 1.创建 CountDownLatch 对象,并指定参与计数的线程的个数
         */
        final CountDownLatch countDownLatch = new CountDownLatch(10);
        for (int i = 0; i < 10; i++) {
            new Thread(() -> {

                try {
                    doBiz();
                } finally {
                    /**
                     * 2.初始容量是10
                     * 每个线程调用这里会将初始值减1
                     */
                    countDownLatch.countDown();
                }

            }, "t-" + i).start();
        }

        try {
            System.out.println("主线程等待所有子线程任务结束.....waiting");
            /**
             * 3.主线程挂起到这里
             * 当计数器的值变成0,将继续往下执行
             */
            countDownLatch.await();
            System.out.println("所有子线程任务结束.....done");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

    }

    private static void doBiz() {
        try {
            try {
                //模拟业务处理
                Thread.sleep(300L);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName() + " 子线程开始运行....");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
}

我们看下效果: image.png

接下来我们就看下源码的逻辑

1.CountDownLatch构造器

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

然后我们再看下Sync这个内部类:

//继承了AQS类
private static final class Sync extends AbstractQueuedSynchronizer {

    Sync(int count) {
        //构造器传递的参数设置给了AQS类的同步状态位state
        setState(count);
    }

    int getCount() {
        return getState();
    }

    //以共享的方式获取state值
    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    //以共享的方式释放state的值
    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

构造器相对简单,接下来我们看下他的两个核心方法:await()countDown()

2.CountDownLatch#await()

public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    //1.以共享的方式尝试获取同步状态值,就是看state的值是否为0    
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

2.1 tryAcquireShared(arg)

protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

就是看state的值是否为0,为0的话返回1,否则返回-1

因为子线程还在执行任务,并没有countdown到0,所以返回-1,将执行doAcquireSharedInterruptibly(arg)方法

2.2 doAcquireSharedInterruptibly(arg)

private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    //1.创建一个共享的Node节点并入队
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            //获取前驱节点
            final Node p = node.predecessor();
            //如果前驱节点是head,说明head的下一个节点就是即将被唤醒去抢锁的节点
            if (p == head) {
                //这个方法上面看到了,就是在挂起之前再次去尝试检查state的值,看是否为0
                //=====》这里我们假设子线程依然没有将state值减为0 《======
                int r = tryAcquireShared(arg);
                //如果为0,说明所有子线程已经将state的初始值10countDown到0了
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            //如果前驱节点不是head,或者检查state的值依然小于0(也就是state的值没有减到0)
            //shouldParkAfterFailedAcquire(p, node)修改前驱节点的状态值为SIGNAL,这样它就具备了唤醒下一个节点的能力
            //parkAndCheckInterrupt() 当前线程调用LockSupport.park()方法挂起
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

这个方法我们再分析AQS源码时有看到这个方法,细节上可能稍有不同,但整体逻辑是一模一样的。

那么最终主线程就是这样:

image.png

到这里,主线程就挂起了,直到有人将它唤醒继续往下执行,那么接下来我们就看下是如何唤醒主线程的

3.CountDownLatch#countDown()

每个线程调用countDown()方法,最终将调用SyncreleaseShared(1)方法:

//参数值是1
public final boolean releaseShared(int arg) {
    //tryReleaseShared(arg) 方法就是使state值减1
    if (tryReleaseShared(arg)) {
        //如果state减成0,则调用这个方法
        doReleaseShared();
        return true;
    }
    return false;
}

3.1 tryReleaseShared(arg)

protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        //获取state的值
        int c = getState();
        if (c == 0) {
            //这里什么时候会进来?
            //比如state设置为10,但是线程有11个
            return false;
        }    
        //使state的值每次减少1    
        int nextc = c-1;
        //CAS设置减1后的值
        if (compareAndSetState(c, nextc))
            //判断是否为0
            return nextc == 0;
    }
}

假设state值从10减少为0了,接下来将执行唤醒主线程的动作

3.2 doReleaseShared()

private void doReleaseShared() {
    //死循环搭配CAS做自旋
    for (;;) {
        Node h = head;
        //判断条件队列中是否有有效的节点
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            //节点的状态在主线程挂起前,已经在shouldParkAfterFailedAcquire(p, node)方法中修改成了 SIGNAL
            if (ws == Node.SIGNAL) {
                //CAS再次修改为0
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                    
                //唤醒后继节点  
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        
        //唤醒之后,来到这里,他们就不相等了,为什么?
        //因为唤醒之后,挂起的主线程将继续往下执行,然后将自己设置新的head节点
        //这一块请看主线程LockSupport.park()之后的逻辑
        //然后继续循环,h != null && h != tail 条件不成立,退出了doReleaseShared()
        if (h == head)             
            break;
    }
}

3.3 唤醒后继节点(主线程)

//节点Node是头节点
private void unparkSuccessor(Node node) {
    
    //状态值是0,前面通过CAS修改过
    int ws = node.waitStatus;
    if (ws < 0)
        compareAndSetWaitStatus(node, ws, 0);

    /*
     * 找到头节点的下一个节点(因为队列是FIF0,所以谁先入队谁先被唤醒)
     */
    Node s = node.next;
    //这里是检查排队的线程是否被取消了,如果取消了就从同步队列的尾部往前找,直到找到一个不是取消的节点
    // waitStatus = 1 表示被取消了
    if (s == null || s.waitStatus > 0) {
        s = null;
        for (Node t = tail; t != null && t != node; t = t.prev)
            if (t.waitStatus <= 0)
                s = t;
    }

    //我这里的s就是主线程,所以这里不为空,然后唤醒主线程。
    if (s != null)
        LockSupport.unpark(s.thread);
}

3.4 主线程被唤醒

image.png

主线程被唤醒后,将继续往下执行,也就是执行for循环。首先获取前驱节点并判断是否为头节点,如果是,则去检查state的值是否为0,如果为0,说明所有子线程的任务都结束了,然后重新设置新的头节点和传播行为。

image.png

3.5 设置新的头节点和传播行为

private void setHeadAndPropagate(Node node, int propagate) {
    Node h = head; // Record old head for check below
    //将当前的主线程节点设置为新的头节点(傀儡节点),线程数据都被清空
    setHead(node);
    /*
     * propagate 值为1,前面的tryAcquireShared(arg)方法检查state是否为0,如果
     * 为0,则返回1,否则返回-1
     */
    if (propagate > 0 || h == null || h.waitStatus < 0 ||
        (h = head) == null || h.waitStatus < 0) {
        //获取下一个节点
        Node s = node.next;
        //如果是null说明没有下一个节点,再去doReleaseShared()目的是为了double check
        //如果还有下一个节点,并且是共享的节点,则去唤醒,前面我们也说了,CountDownLatch适合一个线程
        //或者多个线程去等待其他线程执行完任务后在执行。所以多个线程等待也是适用的。
        if (s == null || s.isShared())
            //这个方法前面已经看过了,就是唤醒节点
            doReleaseShared();
    }
}

4.CountDownLatch的使用场景

CountDownLatch是一个基于AQS提供的倒计时同步类,主要的适用场景有:

  1. 某一线程在开始运行前等待N个线程执行完毕。就是我们上面举的例子,主线程要等所有的子线程执行完后再去工作。
  2. 多个线程等待某一个线程的信号,同时开始执行(类比于。这个和上面有点相反,我们通过代码来演示下。
public class CountDownLatchDemo2 {

    public static void main(String[] args) {
        final CountDownLatch countDownLatch = new CountDownLatch(1);

        System.out.println("运动员开始准备.........");
        for (int i = 1; i <= 5; i++) {
            final int number = i;
            new Thread(() -> {
                System.out.println(number + "号运行员准备好了...");
                try {
                    countDownLatch.await();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }

                System.out.println(number + "号运行员开始起跑...");
            }).start();
        }

        try {
            Thread.sleep(3000L);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println("裁判吹哨,比赛开始.........");
        countDownLatch.countDown();
    }
}

好了,关于CountDownLatch就介绍到这里吧,欢迎补充和指正!!!

限于作者水平,文中难免有错误之处,欢迎指正,勿喷,感谢感谢