CountDownLatch源码导读

518 阅读5分钟

CountDownLatch是在JDK1.5被引入的,又是并发大神Doug Lea的杰作。 它的作用是让一组线程在等待其他N个线程到达一个统一的状态时再继续执行。

一个典型的应用场景就是:多线程并发执行任务的耗时统计。 任务开始执行前记录下系统时间,然后启动N个线程并发执行任务,主线程直接await()进入阻塞状态,单个线程任务执行完毕调用countDown()递减计数器,所有线程的任务都执行完毕后,主线程被唤醒,计算系统时间差就是任务的整个耗时情况。

CountDownLatch是基于AQS实现的,建议先看AQS的源码:《AQS源码导读》。

源码导读

CountDownLatch源码很简单,代码量不多,读起来还是很轻松的,核心都在AQS里。

属性

CountDownLatch有一个内部类Sync,它继承自AQS,是一个基于AQS共享模式的同步器。 CountDownLatch本身几乎没有什么逻辑处理,全都依赖于内部类Sync

Sync也可以看做是一个共享锁,它使用AQS的state作为计数器,线程能获取到共享锁的前提是state==0。 线程调用countDown()其实就是state减一的过程,当state大于0时,线程调用await()就会入队并Park。当其他线程不断调用countDown()state减至0时,共享资源就算是成功释放了,此时AQS会调用doReleaseShared()唤醒队列中的线程继续执行。

/*
基于AQS实现的共享模式的同步器:
1.使用AQS的state作为计数器。
2.竞争到共享锁的前提:state==0。
3.成功释放锁的标识:state减至0.
 */
private final Sync sync;

构造函数

CountDownLatch只有一个构造函数,创建实例必须指定计数器大小,之后无法修改

// 实例化,指定计数器
public CountDownLatch(int count) {
	// 不能为负数
    if (count < 0) throw new IllegalArgumentException("count < 0");
    // 实际上就是实例化了一个Sync对象
    this.sync = new Sync(count);
}

Sync的构造函数,就是将计数器count赋值给AQS的state

// 将计数器赋值给AQS的state
Sync(int count) {
	setState(count);
}

核心方法

await

线程调用await(),其实是让Sync调用了AQS的acquireSharedInterruptibly()方法:

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

acquireSharedInterruptibly()是AQS的模板方法,从方法名就能看出来,它是响应中断的。 它首先会调用tryAcquireShared()去尝试获取共享资源,如果获取失败,则调用doAcquireSharedInterruptibly将线程入队并Park。

// 共享模式下,响应中断的方式获取资源
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
    	// 如果线程发生中断,抛异常
        throw new InterruptedException();
    // 尝试获取共享资源
    if (tryAcquireShared(arg) < 0)
    	// 获取不到则入队并Park
        doAcquireSharedInterruptibly(arg);
}

tryAcquireShared()是子类实现的,逻辑很简单,就是判断state是否等于0:

// 尝试获取共享锁:state==0表示成功获取到。
protected int tryAcquireShared(int acquires) {
	return (getState() == 0) ? 1 : -1;
}

如果state不等于0,则会获取资源失败,AQS会调用doAcquireSharedInterruptibly()

  1. 首先创建一个和当前线程绑定的Node节点,并插入到队尾。
  2. 如果当前节点是头节点,就会在此尝试去竞争资源。
  3. 竞争失败则要将前节点的状态改为SIGNAL,当前线程Park。
  4. 等待其他线程释放资源并唤醒当前线程。
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    // 创建一个和当前线程绑定的Node节点,并添加到队尾
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;//是否竞争失败
    try {
        for (;;) {
        	// 获取当前节点的前驱节点
            final Node p = node.predecessor();
            // 如果前驱节点是头节点,自己就有资格去竞争了
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                	/*
                	竞争成功,将当前节点设为head。
                	因为是共享模式,如果还有剩余资源可用,需要唤醒后继节点。
                	*/
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            /*
			竞争失败,或自己压根就没资格去竞争,则判断是否需要Park。
			Park的前提条件:将前驱节点的waitStatus设为SIGNAL
			*/
            if (shouldParkAfterFailedAcquire(p, node) &&
            	// 需要Park,则LockSupport.park()挂起线程
                parkAndCheckInterrupt())
                // 如果线程发生中断了,则抛异常
                throw new InterruptedException();
        }
    } finally {
        if (failed)
        	// 竞争失败后取消竞争了,将节点状态设为CANCELLED
            cancelAcquire(node);
    }
}

await()默认会无期限的等待,直到被唤醒或中断,CountDownLatch还提供了一个支持超时等待的重载方法。

// 超时等待,也是响应中断的
public boolean await(long timeout, TimeUnit unit)
		throws InterruptedException {
	return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

它调用的是AQS的tryAcquireSharedNanos()方法,逻辑和acquireSharedInterruptibly()差不多,也是先尝试获取共享资源,区别是如果获取不到,会调用doAcquireSharedNanos()

// 支持超时的获取共享资源
public finalboolean tryAcquireSharedNanos(int arg, long nanosTimeout)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    return tryAcquireShared(arg) >= 0 ||
    	// 获取失败,入队并Park给定时间
        doAcquireSharedNanos(arg, nanosTimeout);
}

doAcquireSharedNanos()doAcquireSharedInterruptibly()逻辑差不多,区别是doAcquireSharedInterruptibly()调用LockSupport.park()无期限的挂起线程,而doAcquireSharedNanos()则是计算线程需要别挂起的时间,调用LockSupport.parkNanos()挂起,到期OS会自动将线程唤醒。

private boolean doAcquireSharedNanos(int arg, long nanosTimeout)
        throws InterruptedException {
    if (nanosTimeout <= 0L)
    	// 不用挂起
        return false;
    // 计算到期时间
    final long deadline = System.nanoTime() + nanosTimeout;
    // 创建节点并入队
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                	// 获取到资源了
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return true;
                }
            }
            // 计算应该被挂起的时间
            nanosTimeout = deadline - System.nanoTime();
            if (nanosTimeout <= 0L)
                return false;
            // 判断竞争失败后是否需要被Park
            if (shouldParkAfterFailedAcquire(p, node) &&
                nanosTimeout > spinForTimeoutThreshold)
                // 需要,则Park给定时间,到期自动唤醒
                LockSupport.parkNanos(this, nanosTimeout);
            if (Thread.interrupted())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
        	// 最终没有获取到资源,则退出竞争
            cancelAcquire(node);
    }
}

countDown

countDown()就是将state减一的过程:

/*
CAS+自旋的方式递减state。
当state减少至0时,AQS将唤醒队列中所有的线程。
 */
public void countDown() {
	sync.releaseShared(1);
}

releaseShared()是AQS的模板方法:

  1. 尝试释放共享资源,这里就是state减一,当state==0代表成功释放。
  2. 如果成功释放,则以共享的方式唤醒队列中的线程。
// 释放共享资源
public finalboolean releaseShared(int arg) {
	// 尝试释放共享资源
    if (tryReleaseShared(arg)) {
    	// 释放成功,以共享的方式唤醒队列中的线程
        doReleaseShared();
        return true;
    }
    return false;
}

tryReleaseShared()是子类实现的,逻辑很简单,就是CAS+自旋的方式让state减一,然后判断是否等于0:

/*
尝试释放锁:state减至0表示成功释放。
成功释放后,AQS会调用doReleaseShared()唤醒队列中的节点。
 */
protected boolean tryReleaseShared(int releases) {
	// CAS+自旋重试的方式,state递减。
	for (;;) {
		int c = getState();
		if (c == 0)
			return false;
		int nextc = c-1;
		if (compareAndSetState(c, nextc))//CAS失败,自旋重试
			return nextc == 0;
	}
}

如果state减至0,AQS会调用doReleaseShared()唤醒队列中的线程:

// 共享模式下唤醒队列中的线程
private void doReleaseShared() {
    for (;;) {
        Node h = head;
        // h==null,说明队列是空的
        // head==tail,说明队列中没有节点等待唤醒
        if (h != null && h != tail) {
            int ws = h.waitStatus;
           	// 节点状态=SIGNAL,说明后继节点在等待被唤醒
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                // 唤醒后继节点,后继节点唤醒竞争到资源后,会调用setHeadAndPropagate()
                // 如果还有剩余资源可以,会继续唤醒后继节点,因此当前线程没有必要将节点全部唤醒
                unparkSuccessor(h);
            }
            // ???
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        /*
        如果h!=head,说明唤醒的后继节点已经竞争到资源,并将head指向它了,说明可能还有资源可用,
        后面的节点还有被唤醒的机会,因此自旋重试。
        */
        if (h == head)
            break;
    }
}

getCount

getCount()可以返回当前计数器的值:

/*
获取当前计时器,返回的就是AQS的state变量。
 */
public long getCount() {
	return sync.getCount();
}

Sync的getCount():

// 获取计数器,AQS的state变量。
int getCount() {
	return getState();
}

总结

CountDownLatch是基于AQS共享模式实现的同步工具类,它允许一组线程在其他线程的前置操作全部完成之前,一直阻塞。 它使用AQS的state作为计数器,当计数器大于0时,线程会被入队并挂起,随着其他线程不断调用countDown(),计数器不断递减,state减至0时,CountDownLatch就会开放,AQS会唤醒队列中所有的线程继续执行。