三十九、并发工具之CountDownLatch

122 阅读4分钟

CountDownLatch

概述

CountDownLatch是并发安全的计数器,常用于2个及以上业务可以并发执行,并且需要知道执行完成才能进行下一步操作的场合。

CountDownLatch可以设置一个初始值,这个值就是并发执行任务的数量。每有一个任务执行完成后,调用CountDown方法,初始值就会减1,当初始值减到0时,说明任务都已经完成。

CountDownLatch用法

背景栏:

假设有三个任务需要并发执行,执行完成后才能进行下一步操作。

板书栏:

public class Test {

    private static ThreadPoolExecutor pool = (ThreadPoolExecutor) Executors.newFixedThreadPool(3);

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(3);
        System.out.println("开始执行主业务任务");
        pool.execute(new A(latch));
        pool.execute(new B(latch));
        pool.execute(new C(latch));
        // latch.await();
        if (latch.await(1, TimeUnit.SECONDS)) {
            System.out.println("任务执行完成");
        } else {
            System.out.println("任务没有执行完成");
        }
        System.out.println("结束执行主业务任务");
        pool.shutdown();
    }

    static class A implements Runnable{

        private CountDownLatch latch;

        A (CountDownLatch latch) {
            this.latch = latch;
        }

        @Override
        public void run() {
            System.out.println("开始执行a任务");
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println("结束执行a任务");
            latch.countDown();
        }
    }

    static class B implements Runnable {

        private CountDownLatch latch;

        B (CountDownLatch latch) {
            this.latch = latch;
        }

        @Override
        public void run() {
            System.out.println("开始执行b任务");
            try {
                Thread.sleep(1500);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println("结束执行b任务");
            latch.countDown();
        }
    }

    static class C implements Runnable {

        private CountDownLatch latch;

        C (CountDownLatch latch) {
            this.latch = latch;
        }

        @Override
        public void run() {
            System.out.println("开始执行c任务");
            try {
                Thread.sleep(2000);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println("结束执行c任务");
            latch.countDown();
        }
    }
}

要点栏:

要点1:

CountDownLatch是不可以被复用的,所以一般CountDownLatch作为局部变量声明出来使用。

要点2:

await方法有两种,一种是无参的,会一直等待任务执行完成;另一种是带有时间参数的,如果在时间内完成任务,返回true,否则返回false。如果任务花费时间比await方法等待时间短,await方法也会提前结束。

要点3:

countDown方法是不会将CountDownLatch的数值减到负数的,方法内部做了健壮性判断。

分析栏:

有参的await方法主要是用于防止网络波动情况下,如果任务中有网络io操作,可能任务一直结束不了,await方法给了一个最后时间。

CountDownLatch的源码

基本属性

板书栏:

public class CountDownLatch {
   
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
			// 将数据赋值给AQS中的state属性
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

	// 构造函数
    public CountDownLatch(int count) {
		// count小于0,直接抛出异常
        if (count < 0) throw new IllegalArgumentException("count < 0");
		// 声明Sync对象
        this.sync = new Sync(count);
    }
}

要点栏:

要点1:

CountDownLatch是基于AQS实现的计数器,数据赋值给了AQS中的state属性。

要点2:

CountDownLatch有一个内部类Sync,其重写了AQS中的tryAcquireShared和tryReleaseShared方法。

基本方法

await方法

板书栏:

public void await() throws InterruptedException {
	sync.acquireSharedInterruptibly(1);
}

public final void acquireSharedInterruptibly(int arg)
		throws InterruptedException {
	if (Thread.interrupted())
		// 线程中断,抛中断异常
		throw new InterruptedException();
	// 任务完成,tryAcquireShared返回1
	// 任务未完成,tryAcquireShared返回-1
	if (tryAcquireShared(arg) < 0)
		// 当前线程封装成Node,添加到AQS的双向链表中
		// 并且当前线程挂起
		doAcquireSharedInterruptibly(arg);
}

protected int tryAcquireShared(int acquires) {
	// state值为0,返回1,否则返回-1
	return (getState() == 0) ? 1 : -1;
}

private void doAcquireSharedInterruptibly(int arg)
	throws InterruptedException {
	// 当前线程封装成Node,添加到AQS的双向链表中
	// 以共享锁的方式
	final Node node = addWaiter(Node.SHARED);
	boolean failed = true;
	try {
		for (;;) {
			// 当前线程的前一个节点
			final Node p = node.predecessor();
			if (p == head) {
				// state值为0,返回1,否则返回-1
				int r = tryAcquireShared(arg);
				if (r >= 0) {
					setHeadAndPropagate(node, r);
					p.next = null; 
					failed = false;
					return;
				}
			}
			// 当前线程挂起
			if (shouldParkAfterFailedAcquire(p, node) &&
				parkAndCheckInterrupt())
				throw new InterruptedException();
		}
	} finally {
		if (failed)
			cancelAcquire(node);
	}
}

countDown方法

板书栏:

public void countDown() {
	sync.releaseShared(1);
}

public final boolean releaseShared(int arg) {
	// state减1后,判断state是否为0
	// 如果state为0,tryReleaseShared返回true,否则返回false
	if (tryReleaseShared(arg)) {
		// state为0,说明任务完成
		// 唤醒挂起的线程
		doReleaseShared();
		return true;
	}
	return false;
}

protected boolean tryReleaseShared(int releases) {
	for (;;) {
		int c = getState();
		if (c == 0)
			return false;
		int nextc = c-1;
		if (compareAndSetState(c, nextc))
			return nextc == 0;
	}
}

private void doReleaseShared() {
	for (;;) {
		Node h = head;
		if (h != null && h != tail) {
			int ws = h.waitStatus;
			// 判断头节点的状态是否为-1,如果是,说明后面有未唤醒的节点
			if (ws == Node.SIGNAL) {
				// 通过CAS将head节点状态从-1,改成0
				if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
					continue;  
				// 唤醒head节点后面的正常节点
				unparkSuccessor(h);
			}
			// 这里是为了处理semaphore的BUG
			else if (ws == 0 &&
					 !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
				continue;    
		}
		// 这里是为了处理semaphore的BUG
		if (h == head)        
			break;
	}
}