CountDownLatch源码分析

57 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第12天,点击查看活动详情

简介

CountDownLatch是一个同步工具类,典型的用法是将一个程序分为n个互相独立的任务,同时创建一个为n的计数器,每当一个任务完成时,都会调用countDown,等待问题被解决的任务调用这个锁存器的await,将他们自己拦住,直至锁存器计数结束。

CountDownLatch内部类

CountDownLatch中有一个Sync的内部类,该类继承了AQS

private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;

    Sync(int count) {
        setState(count);
    }
    //返回当前计数
    int getCount() {
        return getState();
    }
    // 试图在共享模式下获取对象状态
    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }
    //尝试设置状态来反映共享模式下的一个释放
    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

构造函数

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

初始化的时候会传入一个初始化参数,这个参数会设置到AQS的state属性上

核心方法 -- await()

此方法会在计数器为0之前,使当前线程一直处于等待状态,除非线程中断

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

由源码可以看出此方法最终调用的方法在AQS中的acquireSharedInterruptibly

public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}

而在acquireSharedInterruptibly方法中先调用tryAcquireShared判断state是否为0,如果小于0则走doAcquireSharedInterruptibly方法

private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
    // 添加节点至等待队列
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) { // 无限循环
            // 获取node的前驱节点
            final Node p = node.predecessor();
            if (p == head) { // 前驱节点为头节点
                // 试图在共享模式下获取对象状态
                int r = tryAcquireShared(arg);
                if (r >= 0) { // 获取成功
                    // 设置头节点并进行繁殖
                    setHeadAndPropagate(node, r);
                    // 设置节点next域
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt()) // 在获取失败后是否需要禁止线程并且进行中断检查
                // 抛出异常
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

调用到这里后还有可能会调用tryAcquireShared和setHeadAndPropagate

setHeadAndPropagate方法

private void setHeadAndPropagate(Node node, int propagate) {
    // 获取头节点
    Node h = head; // Record old head for check below
    // 设置头节点
    setHead(node);
    // 进行判断
    if (propagate > 0 || h == null || h.waitStatus < 0 ||
        (h = head) == null || h.waitStatus < 0) {
        // 获取节点的后继
        Node s = node.next;
        if (s == null || s.isShared()) // 后继为空或者为共享模式
            // 以共享模式进行释放
            doReleaseShared();
    }
}

该方法设置头节点并且释放头节点后面的满足条件的结点,该方法中可能会调用到AQS的doReleaseShared方法

private void doReleaseShared() {
    for (;;) {
        // 保存头节点
        Node h = head;
        if (h != null && h != tail) { // 头节点不为空并且头节点不为尾结点
            // 获取头节点的等待状态
            int ws = h.waitStatus; 
            if (ws == Node.SIGNAL) { // 状态为SIGNAL
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) // 不成功就继续
                    continue;            // loop to recheck cases
                // 释放后继结点
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                        !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) // 状态为0并且不成功,继续
                continue;                // loop on failed CAS
        }
        if (h == head) // 若头节点改变,继续循环  
            break;
    }
}

到此为止整个await方法的调用的调用链就已经很清楚了,接下来看另一个核心方法

核心函数 -- countDown()

此函数将递减锁存器的计数,如果计数到达零,则释放所有等待的线程

public void countDown() {
    sync.releaseShared(1);
}

调用内部类Sync的releaseShared方法

public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}

在releaseShared中又调用的tryReleaseShared 和 doReleaseShared()方法,这个tryReleaseShared会试图设置状态来反映共享模式下的一个释放

protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

doReleaseShared

private void doReleaseShared() {
    // 无限循环
    for (;;) {
        // 保存头节点
        Node h = head;
        if (h != null && h != tail) { // 头节点不为空并且头节点不为尾结点
            // 获取头节点的等待状态
            int ws = h.waitStatus; 
            if (ws == Node.SIGNAL) { // 状态为SIGNAL
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) // 不成功就继续
                    continue;            // loop to recheck cases
                // 释放后继结点
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                        !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) // 状态为0并且不成功,继续
                continue;                // loop on failed CAS
        }
        if (h == head) // 若头节点改变,继续循环  
            break;
    }
}

CountDownLatch示例

package com.example.demo.thread.countdownlatch;

import java.util.concurrent.CountDownLatch;

public class CountDownLatchDemo {
    public static void main(String[] args) throws Exception {
        CountDownLatch latch = new CountDownLatch(2);

        new Thread(){
            @Override
            public void run() {
                try{
                    Thread.sleep(1000);
                    System.out.println("线程1开始执行,休眠1秒。。。。");
                    Thread.sleep(1000);

                    System.out.println("线程1执行countDown操作.....");
                    latch.countDown();

                    System.out.println("线程1完成countDown操作.....");
                }catch (Exception e){
                    e.printStackTrace();
                }
            }
        }.start();

        new Thread(){
            @Override
            public void run() {
                try{
                    Thread.sleep(1000);
                    System.out.println("线程2开始执行,休眠1秒。。。。");
                    Thread.sleep(1000);

                    System.out.println("线程2执行countDown操作.....");
                    latch.countDown();

                    System.out.println("线程2完成countDown操作.....");
                }catch (Exception e){
                    e.printStackTrace();
                }
            }
        }.start();

        System.out.println("main线程准备执行countDownLatch的await操作,将会同步阻塞等待");
        latch.await();

        System.out.println("阻塞等待结束");
    }
}

运行结果

image.png