JUC - CountDownLatch

183 阅读2分钟

CountDownLatch,同步工具类,由 Doug Lea 大神在 jdk1.5 中引入,允许一个或多个线程一直等待直到其他线程运行完成后再执行。

CountDown(倒计时),Latch(门闩),倒计时到了(state=0),门闩才打开。

摘录源码中的一段注释:

A synchronization aid that allows one or more threads to wait until a set of operations being performed in other threads completes.

常用方法


  1. CountDownLatch(int count):构造方法,创建一个值为count 的计数器。
  2. await():将当前线程加入阻塞队列,阻塞当前线程。
  3. await(long timeout, TimeUnit unit):在 timeout 单位时间内阻塞当前线程。
  4. countDown():计数器递减1,减至0时,当前线程唤醒阻塞队列里的所有线程。

示例


public class CountDownLatchDemo {
    public static void main(String[] args) {
        int num = 5;
        ExecutorService executor = Executors.newFixedThreadPool(num);
        // 初始化 state = 5
        CountDownLatch latch = new CountDownLatch(num);

        CountDownLatchDemo latchDemo = new CountDownLatchDemo();
        for (int i = 0; i < num; i++) {
            latchDemo.sign(i + 1, executor, latch);
        }

        try {
            // state != 0 时,门闩关闭,阻塞
            lattch.await();
            System.out.println("人员到齐,比赛开始...");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    private void sign(int i, ExecutorService executor, CountDownLatch latch) {
        executor.submit(() -> {
            try {
                TimeUnit.SECONDS.sleep(i);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }

            System.out.println(i + "号参赛选手已到位");
            // 倒计时,state-1
            latch.countDown();
        });
    }
}

执行结果:

1号参赛选手已到位
2号参赛选手已到位
3号参赛选手已到位
4号参赛选手已到位
5号参赛选手已到位
人员到齐,比赛开始...

好了,下面我们来看一下它的庐山真面目。

类图


基于同步队列 AQS,java.util.concurrent 包中多处都能看到它的身影。

CountDownLatch-uml.png

内部类


/**
 * Synchronization control For CountDownLatch.
 * Uses AQS state to represent count.
 */
private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;

    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }

    // 重写 AQS 的 tryAcquireShared方法
    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    // 重写 AQS 的 tryReleaseShared方法
    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c - 1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

成员变量


Sync 继承于 AbstractQueuedSynchronizer。

private final Sync sync;

构造方法


创建同步队列,并设置初始计数器值,值赋给 AQS 的 state 变量,CountDownLatch 的核心就是控制该变量。

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

countDown()


countDown() 倒数计数。内部类 sync 调用父类 AQS 的模版方法 releaseShared(int arg) 释放资源。

public void countDown() {
    sync.releaseShared(1);
}

releaseShared() 方法中首先调用重写父类的 tryReleaseShared 方法,判断资源释放是否成功(state-1=0),成功则唤醒阻塞的所有线程。

public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

tryReleaseShared() 递减计数,当 state 到 0 时唤醒所有阻塞线程。

protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            // 获取 state 的值如果 == 0 则直接返回 false
            return false;
        int nextc = c - 1;
        // state - 1 后 CAS 原子更新
        if (compareAndSetState(c, nextc))
            // 更新成功并且 nextc == 0,则返回true执行doReleaseShared方法唤醒阻塞线程。
            return nextc == 0;
    }
}

doReleaseShared() 唤醒所有阻塞队列里面的线程

private void doReleaseShared() {
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!h.compareAndSetWaitStatus(Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !h.compareAndSetWaitStatus(0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}

await()


使当前线程在 CountDownLatch 计数减至 0 之前一直等待(构建节点加入阻塞队列,挂起当前线程),除非线程被中断。

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

计数器不为 0 则把当前调用线程加入阻塞队列(支持中断)。

public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

tryAcquireShared(arg) 返回 1(state == 0)则不会阻塞。通过调用 AQS 的 getState() 方法获取 state 的值。

protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

// AQS
protected final int getState() {
    return state;
}

tryAcquireShared(arg) 返回 -1(state != 0) 进入doAcquireSharedInterruptibly方法,构建阻塞队列双向链表,挂起当前线程。

private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } catch (Throwable t) {
        cancelAcquire(node);
        throw t;
    }
}