聊聊CountDownLatch 源码

1,640 阅读6分钟

一、CountDownLatch是什么?

A synchronization aid that allows one or more threads to wait until a set of operations being performed in other threads completes.

CountDownLatch 允许一个或多个线程等待直到在其他线程中执行的一组操作完成的同步辅助。
CountDownLatch 内部维护了一个计数器,调用 await()阻塞当前线程,每个线程完成自己的操作之后调用countDown() 都会将计数器减一,然后会在计数器的值变为 0 之前一直阻塞,直到计数器的值变为 0,释放所有等待的线程;

二、源码解析

构造函数内部,初始化一个Sync(count)

//java.util.concurrent.CountDownLatch
public CountDownLatch(int count) {
   this.sync = new Sync(count);
}

private static final class Sync extends AbstractQueuedSynchronizer {
    Sync(int count) {
       //AQS中的state值,充当计数器
       setState(count);
    }
}

1.图解AQS框架

82077ccf14127a87b77cefd1ccf562d3253591.png

2.AQS内部类Node属性介绍

0001111.png

3.countDown()方法里面做了什么?

//java.util.concurrent.CountDownLatch
public void countDown() {
  sync.releaseShared(1);
}

//java.util.concurrent.locks.AbstractQueuedSynchronizer
public final boolean releaseShared(int arg) {
    //AQS里面的tryReleaseShared需要子类覆写
    if (tryReleaseShared(arg)) {
       //state为0的时候,去唤醒等待队列中的线程
       doReleaseShared();
       return true;
    }
    return false;
}

//java.util.concurrent.CountDownLatch.Sync
protected boolean tryReleaseShared(int releases) {
     for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        //减少计数
        int nextc = c - 1;
        //CAS修改计数
        if (compareAndSetState(c, nextc))
            //计数为0才会执行doReleaseShared()
            return nextc == 0;
    }
}

我们来看一下doReleaseShared() 共享模式的释放操作——发出后继信号并确保传播下去。

private void doReleaseShared() {
    for (;;) {
      Node h = head;
      //不是尾节点
      if (h != null && h != tail) {
          int ws = h.waitStatus;
          if (ws == Node.SIGNAL) {
               //避免两次unpark
               if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0)){
                   continue; 
               }
               //唤醒头节点的后继节点
               unparkSuccessor(h);
             }else if (ws == 0 && !h.compareAndSetWaitStatus(0, Node.PROPAGATE)){
                //头节点的等待状态为0
                //使用CAS把当前节点状态设置为PROPAGATE,确保后面可以传递下去
                continue; 
             } 
          }
          //头结点无变更,退出循环
          if (h == head)
             break;
     }
}

4.await()做了什么?

await有两个方法: 一个是:await()
另一个是:await(long timeout, TimeUnit unit) 超时没有执行完的任务,会调用:cancelAcquire

我们看看await(),另一个方法留给大家自己学习一下:

//java.util.concurrent.CountDownLatch
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

//java.util.concurrent.locks.AbstractQueuedSynchronizer
public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
    if (Thread.interrupted())//获取并且清空线程中断标记位
       //如果是中断状态则直接抛InterruptedException异常
       throw new InterruptedException();
    //只有小于0的时候才会加入同步等待队列
    if (tryAcquireShared(arg) < 0)
       doAcquireSharedInterruptibly(arg);
}

private Node addWaiter(Node mode) {
    Node node = new Node(mode);
    for (;;) {
       Node oldTail = tail;
       //尾节点不为空说明已经初始化过了
       if (oldTail != null) {
          //Unsafe.putObject(Object o, int offset, Object x)
          //设置node的前驱节点为oldTail
          U.putObject(node, Node.PREV, oldTail);
          if (compareAndSetTail(oldTail, node)) {
              //oldTail的后继节点设置为node
              oldTail.next = node;
              return node;
           }
        } else {
           //初始化同步队列,头尾节点都是指向同一个新的Node实例
           initializeSyncQueue();
        }
    }
}

private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
   //创建一个共享模式的节点,添加到队列中 
   final Node node = addWaiter(Node.SHARED);
   try {
      for (;;) {
        //返回当前节点的前驱节点
        final Node p = node.predecessor();
        if (p == head) {
          //返回1不再阻塞,出队,-1仍然继续阻塞
          int r = tryAcquireShared(arg);
          if (r >= 0) {
             //往下翻一下,有分析
             setHeadAndPropagate(node, r);
             p.next = null; 
             return;
           }
         }
         //往下翻一下,有分析
         if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()){
            throw new InterruptedException();        
          }      
       }
   } catch (Throwable t) {
       //往下翻一下,有分析
       cancelAcquire(node);
       throw t;
   }
}

/**
* 设置同步等待队列的头节点,判断当前处理的节点的后继节点是否共享模式的节点,
* 如果共享模式的节点,propagate大于0或者节点的waitStatus为PROPAGATE
* 则进行共享模式下的释放资源
*/
private void setHeadAndPropagate(Node node, int propagate) {
    Node h = head; 
    //设置node为头节点
    setHead(node);
    //propagate大于0 || 头节点为null || 头节点的状态为非取消 || 再次获取头节点为null || 再次获取头节点的状态为非取消
    if (propagate > 0 || h == null || h.waitStatus < 0 || (h = head) == null || h.waitStatus < 0) {
        Node s = node.next;
        //后继节点==null或者是共享模式的节点
        if (s == null || s.isShared())
            doReleaseShared();//往上翻,上面分析过了
     }
}

private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
        int ws = pred.waitStatus;
        if (ws == Node.SIGNAL)
            //前驱节点状态设置成Node.SIGNAL成功,等待被release调用释放,后继节点可以安全地进入阻塞状态
            return true;
        if (ws > 0) {
            do {
                node.prev = pred = pred.prev;
                //waitStatus大于0,表示前驱节点已经取消
            } while (pred.waitStatus > 0);
            //找到一个非取消的节点,重新通过next引用连接当前共享模式的节点
            pred.next = node;
        } else {
            //前驱节点非取消状态,全部设置为Node.SIGNAL
            pred.compareAndSetWaitStatus(ws, Node.SIGNAL);
        }
        return false;
    }
    
    // 阻塞当前线程,获取并且重置线程的中断标记位
    private final boolean parkAndCheckInterrupt() {
        //来了来了,关键的方法:阻塞线程的实现,依赖Unsafe的API
        LockSupport.park(this);
        return Thread.interrupted();
    }

我们再看一下cancelAcquire(node)里面做了什么:

// java.util.concurrent.locks.AbstractQueuedSynchronizer

private void cancelAcquire(Node node) {
        if (node == null)
            return;
        //此时节点的线程已经中断取消,置空节点的线程
        node.thread = null;
        Node pred = node.prev;
        //(跳过取消状态的节点)获取当前节点的上一个非取消状态的节点
        while (pred.waitStatus > 0)
            node.prev = pred = pred.prev;
           
        //保存node.prev非取消状态节点的后继节点
        Node predNext = pred.next;
        //更新当前节点状态=取消
        node.waitStatus = Node.CANCELLED;
        // 如果当前节点是尾节点,将当前节点的上一个非取消状态的节点设置为尾节点
        // 更新失败的话,则进入else,如果更新成功,将tail的后继节点设置为null
        if (node == tail && compareAndSetTail(node, pred)) {
            pred.compareAndSetNext(predNext, null);
        } else {
            int ws;
            // 如果当前节点不是head的后继节点
            // 1:判断当前节点前驱节点的是否为SIGNAL,
            // 2:如果不是,则把前驱节点设置为SINGAL看是否成功
            // 如果1和2中有一个为true,再判断当前节点的线程是否为null
            // 如果上述条件都满足,把当前节点的前驱节点的后继指针指向当前节点的后继节点
            if (pred != head &&
                ((ws = pred.waitStatus) == Node.SIGNAL ||
                 (ws <= 0 && pred.compareAndSetWaitStatus(ws, Node.SIGNAL))) &&
                pred.thread != null) {
                Node next = node.next;
                if (next != null && next.waitStatus <= 0)
                    pred.compareAndSetNext(predNext, next);
            } else {
                //上述条件不满足,唤醒当前节点的后继节点
                unparkSuccessor(node);
            }
            node.next = node;// help GC
        }
    }

private void unparkSuccessor(Node node) {
	// 获取头节点waitStatus
	int ws = node.waitStatus;
	if (ws < 0)
            compareAndSetWaitStatus(node, ws, 0);
	// 获取当前节点的下一个节点
	Node s = node.next;
	// 如果下个节点是null或者下个节点被cancelled,就找到队列最开始的非cancelled的节点
	if (s == null || s.waitStatus > 0) {
            s = null;
            // 就从尾部节点开始找,到队首,找到队列第一个waitStatus<0的节点。
            for (Node t = tail; t != null && t != node; t = t.prev)
                 if (t.waitStatus <= 0)
                     s = t;
	}
	// 如果当前节点的下个节点不为空,而且状态<=0,就把当前节点unpark
	if (s != null)
            LockSupport.unpark(s.thread);
}

cancelAcquire() 调用的地方:

1.主动中断
2.acquire过程中发生异常
3.超时版本的API调用的时候剩余超时时间小于等于零的时候

cancelAcquire() 主要作用是把取消的节点移出同步等待队列,满足上面代码里面分析的条件,会进行后继节点的唤醒unparkSuccessor(node)

5.聊聊LockSupport如何实现阻塞和解除阻塞的?

WX20210807-134633@2x.png 点击查看 JDK11 LockSupport文档地址

先看看下面几个代码片段:

//示例一:
Log.d(TAG,"001")
LockSupport.park(this)
Log.d(TAG,"002")
-------------------
输出: 
001 
阻塞中.......

//示例二:
LockSupport.unpark(Thread.currentThread())
Log.d(TAG,"001")
LockSupport.park(this)
Log.d(TAG,"002")
....
Log.d(TAG,"执行完")
-------------------
输出:
001 
002
执行完

//示例三:
val thread = Thread.currentThread()
cacheThreadPool.execute{
    Log.d(TAG,"一个耗时的异步任务,正在执行...")
    Thread.sleep(1500)
    //提供许可,解除阻塞
    LockSupport.unpark(thread)
}
Log.d(TAG,"001")
//阻塞当前线程
LockSupport.park(this)
Log.d(TAG,"002")
....
Log.d(TAG,"执行完")
-------------------
输出:
001 
一个耗时的异步任务,正在执行...
002
执行完

看了上面的示例,是不是懂了?🙈🙈
LockSupportpark方法有两个:

//java.util.concurrent.locks.LockSupport
public static void park(Object blocker) {
  Thread t = Thread.currentThread();
  setBlocker(t, blocker);
  U.park(false, 0L);
  setBlocker(t, null);
}
public static void park() {
   U.park(false, 0L);
}

/**
* 通过反射机制获取Thread类的parkBlocker字段对象。
* 然后通过sun.misc.Unsafe对象的objectFieldOffset方法,
* 获取到parkBlocker在内存里的偏移量
*/
private static void setBlocker(Thread t, Object arg) {
  U.putObject(t, PARKBLOCKER, arg);
}

static {
    PARKBLOCKER = U.objectFieldOffset
                (Thread.class.getDeclaredField("parkBlocker"));
}

用谁?还用疑问🤔吗?当然推荐大家使用有参数的park(blocker)方法啦。

我们看一下点击查看Thread源码里面的parkBlocker:

/** 
 * The argument supplied to the current call to 
 * java.util.concurrent.locks.LockSupport.park. 
 * Set by (private) java.util.concurrent.locks.LockSupport.setBlocker 
 * Accessed using java.util.concurrent.locks.LockSupport.getBlocker 
 */  
volatile Object parkBlocker;

parkBlocker对象是用来记录线程被阻塞是被谁阻塞的,用于线程监控和分析工具来定位原因的。
LockSupport通过getBlocker获取到阻塞的对象,主要用于监控和分析线程。


park阻塞、unpark解除阻塞,最终会调用UnSafe内部对应的native方法的实现
点击查看UnSafe源码

三、使用场景

1.ARouter加载指定包名下class集合的用法

//com.alibaba.android.arouter.utils.ClassUtils
fun getFileNameByPackageName():Set<String>{
     val classNames: Set<String> = HashSet()
    val paths = getSourcePaths(context)
    val parserCtl = CountDownLatch(paths.size())
    for (path in paths) {
       DefaultPoolExecutor.getInstance().execute{
          try{
            //耗时的加载数据....
          }finnaly{
              parserCtl.countDown()
          }
       }
    }
    parserCtl.await()
    return classNames
}

2.假设有三个线程:A/B/C,我们在A/B完成或者部分完成的时候启动C

class TaskThread(private val taskName:String, private val countDownLatch: CountDownLatch, private val testSleep:Long): Thread() {
    override fun run() {
       Log.d(TAG,"『开始』执行${taskName}任务,来自:${currentThread().name}")
       sleep(testSleep)
       Log.d(TAG,"【结束】执行${taskName}任务,来自:${currentThread().name}")
       countDownLatch.countDown()
    }
}

用法如下:
val countDownLatch = CountDownLatch(2)
val taskA = TaskThread("A",countDownLatch,100)
val taskB = TaskThread("B",countDownLatch,500)
taskA.start()
taskB.start()
countDownLatch.await()
Thread {
   Log.d(TAG, "任务C搞个小业务,来自:${Thread.currentThread().name}")
}.start()
-------------------
输出:
『开始』执行A任务,来自:Thread-2
『开始』执行B任务,来自:Thread-3
【结束】执行A任务,来自:Thread-2
【结束】执行B任务,来自:Thread-3
任务C搞个小业务,来自:Thread-4