CountDownLatch源码分析

132 阅读2分钟

使用场景

模拟3个人跑步,等3个人都到达终点后,才能执行某些操作。

public static void main(String[] args) throws Exception{
    // 模拟3个人跑步,等3个人都到达终点后,才能执行某些操作
    CountDownLatch countDownLatch = new CountDownLatch(3);
    
    Thread th1 = new Thread(new Runnable() {			
    	@Override
    	public void run() {
    		System.out.println("begin running...");
    		// 模拟跑步
    		System.out.println("end running...");
    		countDownLatch.countDown();
    	}
    });
    th1.start();
    
    // 这里再创建2个线程,模拟2个人跑步
    ...
    
    // 如果线程1先到达终点,先阻塞,等所有线程都到达终点,才会被唤醒
    countDownLatch.await();
    
    // 3个人都到达终点,执行某些操作
    doSomething();		
}

如何控制多个线程,等待其他线程都执行完毕后再执行?

  • 初始化

    CountDownLatch countDownLatch = new CountDownLatch(3);

    初始化时设置state为3,即当前线程持有锁且重入了3次。

    public CountDownLatch(int count) {
    	if (count < 0) throw new IllegalArgumentException("count < 0");
    	this.sync = new Sync(count);
    }
    
    Sync(int count) {
    	setState(count);
    }
    
  • 挂起线程

    countDownLatch.await();

    会先判断线程是否被中断,如果中断就抛出异常。线程1、线程2、线程3执行到这里都会加入到阻塞队列,并挂起线程,具体看AbstractQueuedSynchronizer#doAcquireSharedInterruptibly()方法。

    public void await() throws InterruptedException {
    	sync.acquireSharedInterruptibly(1);
    }
    public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
    	// 是否被中断,中断就抛出异常
        if (Thread.interrupted())
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }
    private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
    	// 将线程包装成node,添加到队列
    	final Node node = addWaiter(Node.SHARED);
    	boolean failed = true;
    	try {
    	    for (;;) {
    	        final Node p = node.predecessor();
    	        if (p == head) {
    	            int r = tryAcquireShared(arg);
    	            // 初始化后的state=3,所以r=-1,下面的条件不满足,
    	            // 当所有线程执行完countDown方法后,state=0,r=1,条件满足
    	            if (r >= 0) {
    		            // 重写设置头结点,并依次唤醒后续线程
    	                setHeadAndPropagate(node, r);
    	                p.next = null; // help GC
    	                failed = false;
    	                return;
    	            }
    	        }
    	        // 挂起线程,结点唤醒后从这里开始执行
    	        if (shouldParkAfterFailedAcquire(p, node) &&
    	            parkAndCheckInterrupt())
    	            throw new InterruptedException();
    	    }
    	} finally {
    	    if (failed)
    	        cancelAcquire(node);
    	}
    }
    
  • 唤醒操作

    countDownLatch.countDown();

    每次执行countDown时,state减1,当state为0时,将执行AbstractQueuedSynchronizer#doReleaseShared()方法唤醒线程,先唤醒头结点,再依次唤醒等待队列中的后续结点。

    public final boolean releaseShared(int arg) {
    	// 返回true,才会唤醒挂起的线程
    	if (tryReleaseShared(arg)) {
    	    doReleaseShared();
    	    return true;
    	}
    	return false;
    }
    // 假设线程1最先进来,state=3,使用cas设置state=2并返回false
    // 接着线程2进来,state=2,使用cas设置state=1并返回false
    // 接着线程3进来,state=1,使用cas设置state=0并返回true
    protected boolean tryReleaseShared(int releases) {
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
    private void doReleaseShared() {   
    	for (;;) {
    	    Node h = head;
    	    // 先唤醒头结点
    	    if (h != null && h != tail) {
    	        int ws = h.waitStatus;
    	        if (ws == Node.SIGNAL) {
    	            if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
    	                continue;            
    	            unparkSuccessor(h);
    	        }
    	        else if (ws == 0 &&
    	                 !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
    	            continue;                
    	    }
    	    if (h == head)                   
    	        break;
    	}
    }