从源码到实战:Java CountDownLatch深度剖析与原理揭秘

167 阅读10分钟

从源码到实战:Java CountDownLatch深度剖析与原理揭秘

一、引言

在多线程编程领域,线程间的协同与同步是确保程序正确性和高效性的关键。Java提供了丰富的并发工具类来解决这类问题,其中CountDownLatch作为一种强大的同步工具,能够有效控制线程的执行顺序,让一个或多个线程等待其他线程完成一系列操作后再继续执行。本文将从源码层面深入剖析CountDownLatch的实现原理,结合具体代码示例,带您全面了解其工作机制,揭开其神秘面纱。

二、CountDownLatch概述

2.1 定义与作用

CountDownLatch是Java并发包java.util.concurrent中的一个类,它允许一个或多个线程等待其他线程完成一组操作。它内部维护一个计数器,在初始化时指定计数器的初始值,每当一个相关线程完成操作后,计数器就减1,当计数器的值减为0时,等待在CountDownLatch上的所有线程将被释放,继续执行后续操作。

2.2 核心方法

CountDownLatch主要提供了以下两个核心方法:

  1. CountDownLatch(int count):构造函数,用于初始化CountDownLatch,参数count指定计数器的初始值。
  2. void countDown():调用该方法会将计数器的值减1,通常由完成任务的线程调用。
  3. void await():调用该方法的线程会被阻塞,直到计数器的值减为0。如果计数器已经为0,await方法会立即返回。

三、CountDownLatch的源码结构

3.1 类定义与继承关系

// java.util.concurrent.CountDownLatch类定义
public class CountDownLatch {
    // 继承自AbstractQueuedSynchronizer,利用AQS实现同步机制
    private final Sync sync; 

    // 内部静态类,继承自AbstractQueuedSynchronizer
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        // 构造函数,设置初始状态为count
        Sync(int count) {
            setState(count);
        }

        // 获取当前状态(即计数器的值)
        int getCount() {
            return getState();
        }

        // 尝试获取共享锁,始终返回false,因为CountDownLatch不支持获取锁的操作
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0)? 1 : -1;
        }

        // 尝试释放共享锁,成功时将计数器减1,当计数器为0时唤醒所有等待线程
        protected boolean tryReleaseShared(int releases) {
            // 采用CAS操作更新状态
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c - 1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    // CountDownLatch构造函数,初始化Sync实例
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    // 使当前线程等待,直到计数器的值减为0
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    // 使当前线程等待,直到计数器的值减为0,或者等待指定的时间
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    // 将计数器的值减1
    public void countDown() {
        sync.releaseShared(1);
    }

    // 返回当前计数器的值
    public long getCount() {
        return sync.getCount();
    }

    // 重写toString方法,返回CountDownLatch的状态信息
    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}

从上述源码可以看出,CountDownLatch的实现依赖于AbstractQueuedSynchronizer(简称AQS),AQS是Java并发包中用于构建锁和同步器的基础框架。CountDownLatch通过内部类Sync继承AQS,并实现了tryAcquireSharedtryReleaseShared方法,来完成计数器的更新和线程的等待与唤醒操作。

3.2 关键成员变量

  1. sync:类型为Sync,是CountDownLatch的核心同步工具,负责处理计数器的更新和线程的同步操作。
  2. state:在AQS中定义的一个整型变量,用于表示同步状态。在CountDownLatch中,state的值就是计数器的值。

四、CountDownLatch的工作流程详解

4.1 初始化过程

当创建一个CountDownLatch实例时,会调用其构造函数:

// CountDownLatch构造函数
public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

// Sync类的构造函数
Sync(int count) {
    setState(count);
}

在构造函数中,首先检查传入的count值是否小于0,如果小于0则抛出IllegalArgumentException异常。然后创建一个Sync实例,并将count值通过setState方法设置为AQS的同步状态state,即初始化计数器的值。

4.2 await方法执行流程

当线程调用await方法时,会执行以下操作:

// await方法
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

await方法调用了Sync实例的acquireSharedInterruptibly方法,该方法在AQS中定义:

// AbstractQueuedSynchronizer中的acquireSharedInterruptibly方法
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}
  1. 首先检查当前线程是否被中断,如果被中断则抛出InterruptedException异常。
  2. 调用tryAcquireShared方法尝试获取共享锁,该方法在Sync类中实现:
// Sync类中的tryAcquireShared方法
protected int tryAcquireShared(int acquires) {
    return (getState() == 0)? 1 : -1;
}

tryAcquireShared方法检查当前计数器的值(即state的值)是否为0,如果为0表示所有任务已完成,返回1表示获取共享锁成功;否则返回-1表示获取共享锁失败。 3. 如果tryAcquireShared方法返回-1,说明获取共享锁失败,调用doAcquireSharedInterruptibly方法将当前线程加入等待队列并阻塞:

// AbstractQueuedSynchronizer中的doAcquireSharedInterruptibly方法
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    // 创建一个共享模式的节点
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

doAcquireSharedInterruptibly方法中,首先创建一个共享模式的节点并加入等待队列,然后进入一个无限循环:

  • 检查当前节点的前驱节点是否为头节点,如果是头节点则再次尝试获取共享锁(调用tryAcquireShared方法)。
  • 如果获取共享锁成功,调用setHeadAndPropagate方法将当前节点设置为头节点,并唤醒后续等待的节点。
  • 如果获取共享锁失败,调用shouldParkAfterFailedAcquire方法判断当前线程是否应该阻塞,如果应该阻塞则调用parkAndCheckInterrupt方法将线程阻塞。如果线程在阻塞期间被中断,则抛出InterruptedException异常。

4.3 countDown方法执行流程

当线程完成任务后,调用countDown方法将计数器的值减1:

// countDown方法
public void countDown() {
    sync.releaseShared(1);
}

countDown方法调用了Sync实例的releaseShared方法,该方法在AQS中定义:

// AbstractQueuedSynchronizer中的releaseShared方法
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}
  1. 首先调用tryReleaseShared方法尝试释放共享锁,该方法在Sync类中实现:
// Sync类中的tryReleaseShared方法
protected boolean tryReleaseShared(int releases) {
    // 采用CAS操作更新状态
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c - 1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

tryReleaseShared方法通过一个无限循环,采用CAS(Compare And Swap,比较并交换)操作尝试将计数器的值减1。如果当前计数器的值为0,说明已经减到0,直接返回false;否则计算下一个状态值nextc,并使用compareAndSetState方法尝试更新状态。如果更新成功,检查nextc是否为0,如果为0说明所有任务已完成,返回true;否则返回false。 2. 如果tryReleaseShared方法返回true,说明释放共享锁成功,调用doReleaseShared方法唤醒等待队列中所有等待的线程:

// AbstractQueuedSynchronizer中的doReleaseShared方法
private void doReleaseShared() {
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, Node.CONDITION))
                    continue;            // loop to recheck cases
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}

doReleaseShared方法中,通过一个无限循环,检查头节点是否存在且不是尾节点。如果头节点的waitStatusNode.SIGNAL,说明头节点的后继节点需要被唤醒,使用compareAndSetWaitStatus方法将头节点的waitStatus设置为Node.CONDITION,然后调用unparkSuccessor方法唤醒后继节点。如果头节点的waitStatus为0,尝试将其设置为Node.PROPAGATE,以确保唤醒操作能够传播到后续节点。

五、CountDownLatch的典型应用场景

5.1 多个线程完成任务后再执行后续操作

import java.util.concurrent.CountDownLatch;

public class MultipleTasksExample {
    public static void main(String[] args) {
        // 初始化CountDownLatch,计数器初始值为3
        CountDownLatch latch = new CountDownLatch(3); 

        // 创建并启动三个线程
        new Thread(new Task(latch, "线程1")).start();
        new Thread(new Task(latch, "线程2")).start();
        new Thread(new Task(latch, "线程3")).start();

        try {
            // 主线程等待,直到计数器的值减为0
            latch.await();
            System.out.println("所有任务已完成,主线程继续执行...");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    // 任务类,实现Runnable接口
    static class Task implements Runnable {
        private final CountDownLatch latch;
        private final String threadName;

        public Task(CountDownLatch latch, String threadName) {
            this.latch = latch;
            this.threadName = threadName;
        }

        @Override
        public void run() {
            try {
                System.out.println(threadName + " 开始执行任务...");
                // 模拟任务执行时间
                Thread.sleep(2000); 
                System.out.println(threadName + " 任务执行完毕");
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                // 任务完成后,调用countDown方法将计数器减1
                latch.countDown(); 
            }
        }
    }
}

在上述示例中,主线程创建了一个CountDownLatch实例,初始计数器值为3。然后启动了三个线程,每个线程在完成任务后调用countDown方法将计数器减1。主线程调用await方法等待,直到计数器的值减为0,此时说明所有任务已完成,主线程继续执行后续操作。

5.2 模拟并发请求

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ConcurrentRequestExample {
    public static void main(String[] args) {
        // 初始化CountDownLatch,计数器初始值为10
        CountDownLatch latch = new CountDownLatch(1); 
        // 创建线程池
        ExecutorService executorService = Executors.newFixedThreadPool(10); 

        for (int i = 0; i < 10; i++) {
            executorService.submit(new RequestTask(latch));
        }

        try {
            // 模拟准备工作
            Thread.sleep(2000); 
            System.out.println("准备工作完成,释放所有线程...");
            // 将计数器减为0,释放所有等待的线程
            latch.countDown(); 
        } catch (InterruptedException e) {
            e.printStackTrace();
        } finally {
            // 关闭线程池
            executorService.shutdown(); 
        }
    }

    // 请求任务类,实现Runnable接口
    static class RequestTask implements Runnable {
        private final CountDownLatch latch;

        public RequestTask(CountDownLatch latch) {
            this.latch = latch;
        }

        @Override
        public void run() {
            try {
                // 等待计数器减为0
                latch.await();
                System.out.println(Thread.currentThread().getName() + " 开始发送请求...");
                // 模拟请求处理时间
                Thread.sleep(1000); 
                System.out.println(Thread.currentThread().getName() + " 请求处理完毕");
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }
    }
}

在这个示例中,创建了一个线程池并提交了10个任务,每个任务都调用await方法等待CountDownLatch的计数器减为0。主线程在完成准备工作后,调用countDown方法将计数器减为0,此时所有等待的线程将同时开始执行任务,模拟了并发请求的场景。

六、CountDownLatch与其他同步工具的对比

6.1 与CyclicBarrier的对比

  1. 功能差异
    • CountDownLatch主要用于一个或多个线程等待其他线程完成一组操作,计数器的值只能递减一次,不能重置。
    • CyclicBarrier用于一组线程相互等待,当所有线程都到达某个屏障点时,这些线程才会继续执行,并且可以重复使用。
  2. 实现原理差异
    • CountDownLatch基于AQS实现,通过计数器的递减和线程的等待与唤醒机制来实现同步。
    • CyclicBarrier内部维护一个计数器和一个可重入锁,通过条件变量来实现线程的等待和唤醒,并且在计数器归零时会执行一个回调函数(Runnable任务)。
  3. 适用场景差异
    • CountDownLatch适用于简单的一次性同步场景,例如等待多个任务完成后再执行后续操作。
    • CyclicBarrier适用于需要重复进行同步的场景,例如在多线程计算中,每次迭代都需要所有线程同步到某个点后再继续下一次迭代。

6.2 与Semaphore的对比

  1. 功能差异
    • CountDownLatch用于线程间的等待,直到计数器的值减为0,主要控制线程的执行顺序。
    • Semaphore用于控制同时访问某个资源的线程数量,通过获取和释放信号量来实现。
  2. 实现原理差异
    • CountDownLatch基于AQS的共享模式实现,通过计数器的更新和线程的等待唤醒机制来同步线程。
    • Semaphore同样基于AQS实现,但其内部维护的是信号量的数量,线程通过调用`acquire