30行自己写并发工具类(Semaphore, CyclicBarrier, CountDownLatch)

82 阅读10分钟

前言

在本篇文章当中首先给大家介绍三个工具 Semaphore, CyclicBarrier, CountDownLatch 该如何使用,然后仔细剖析这三个工具内部实现的原理,最后会跟大家一起用 ReentrantLock 实现这三个工具。

并发工具类的使用

CountDownLatch

CountDownLatch 最主要的作用是允许一个或多个线程等待其他线程完成操作。比如我们现在有一个任务,有 (N) 个线程会往数组 data[N] 当中对应的位置根据不同的任务放入数据,在各个线程将数据放入之后,主线程需要将这个数组当中所有的数据进行求和计算,也就是说主线程在各个线程放入之前需要阻塞住!在这样的场景下,我们就可以使用 CountDownLatch 。

上面问题的代码:

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.CountDownLatch;

public class CountDownLatchDemo {

    public static int[] data = new int[10];

    public static void main(String[] args) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(10);

        for (int i = 0; i < 10; i++) {
            int temp = i;
            new Thread(() -> {
                Random random = new Random();
                data[temp] = random.nextInt(100001);
                latch.countDown();
            }).start();
        }

        // 只有函数 latch.countDown() 至少被调用10次
        // 主线程才不会被阻塞
        // 这个10是在CountDownLatch初始化传递的10
        latch.await();
        System.out.println("求和结果为:" + Arrays.stream(data).sum());
    }
}

在上面的代码当中,主线程通过调用 latch.await(); 将自己阻塞住,然后需要等他其他线程调用方法 latch.countDown() 只有这个方法被调用的次数等于在初始化时给 CountDownLatch传递的参数时,主线程才会被释放。

CyclicBarrier

CyclicBarrier 它要做的事情是,让一 组线程到达一个屏障(也可以叫同步点)时被阻塞,直到最后一个线程到达屏障时,屏障才会开门,所有被屏障拦截的线程才会继续运行。我们通常也将 CyclicBarrier 称作 路障 。

示例代码:

public class CycleBarrierDemo {

    public static void main(String[] args) {
        CyclicBarrier barrier = new CyclicBarrier(5);

        for (int i = 0; i < 5; i++) {
            new Thread(() -> {
                try {
                    System.out.println(Thread.currentThread().getName() + "开始等待");
                    // 所有线程都会调用这行代码
                    // 在这行代码调用的线程个数不足5
                    // 个的时候所有的线程都会阻塞在这里
                    // 只有到5的时候,这5个线程才会被放行
                    // 所以这行代码叫做同步点 
                    barrier.await();
                    // 如果有第六个线程执行这行代码时
                    // 第六个线程也会被阻塞 知道第10
                    // 线程执行这行代码 6-10 这5个线程
                    // 才会被放行
                } catch (InterruptedException e) {
                    e.printStackTrace();
                } catch (BrokenBarrierException e) {
                    e.printStackTrace();
                }
                System.out.println(Thread.currentThread().getName() + "等待完成");
            }).start();
        }
    }
}

我们在初始化 CyclicBarrier 对象时,传递的数字为 5 ,这个数字表示只有5个线程到达同步点的时候,那5个线程才会同时被放行,而如果到了6个线程的话,第一次没有被放行的线程必须等到下一次有 5 个线程到达同步点 barrier.await() 时,才会放行5个线程。

  • 比如刚开始的时候5个线程的状态如下, 同步点 还没有5个线程到达,因此不会放行。

  • 当有5个线程或者更多的线程到达 同步点 barrier.await() 的时候,才会放行 5 个线程,注意是5个线程,如果有多的线程必须等到下一次集合5个线程才会进行又一次放行,也就是说每次只放行5个线程,这也是它叫做 CyclicBarrier (循环路障)的原因(因为每次放行5个线程,放行完之后重新计数,直到又有5个新的线程到来,才再次放行)。

Semaphore

Semaphore ( 信号量 )通俗一点地来说就是控制能执行某一段代码的线程数量,他可以控制程序的并发量!

semaphore.acquire

(\mathcal{R})

semaphore.release

比如在上面的 acquire 和 release 之间的代码 (\mathcal{R}) 就是我们需要控制的代码,我们可以通过 信号量 控制在某一个时刻能有多少个线程执行代码 (\mathcal{R}) 。在信号量内部有一个计数器,在我们初始化的时候设置为 (N) ,当有线程调用 acquire 函数时,计数器需要减一,调用 release 函数时计数器需要加一,只有当计数器大于0时,线程调用 acquire 时才能够进入代码块 (\mathcal{R}) ,否则会被阻塞,只有线程调用 release 函数时,被阻塞的线程才能被唤醒,被唤醒的时候计数器会减一。

示例代码:

import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;

public class SemaphoreDemo {
    public static void main(String[] args) {
        Semaphore mySemaphore = new Semaphore(5);
        for (int i = 0; i < 10; i++) {
            new Thread(() -> {
                System.out.println(Thread.currentThread().getName() + "准备进入临界区");
                try {
                    mySemaphore.acquire();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(Thread.currentThread().getName() + "已经进入临界区");
                try {
                    TimeUnit.SECONDS.sleep(2);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(Thread.currentThread().getName() + "准备离开临界区");
                mySemaphore.release();
                System.out.println(Thread.currentThread().getName() + "已经离开临界区");
            }).start();
        }
    }
}

自己动手写并发工具类

在这一小节当中主要使用 ReentrantLock 实现上面我们提到的三个并发工具类,因此你首先需要了解 ReentrantLock 这个工具。 ReentrantLock 中有两个主要的函数 lock 和 unlock,主要用于临界区的保护,在同一个时刻只能有一个线程进入被 lock 和 unlock 包围的代码块。除此之外你还需要了解
ReentrantLock.newCondition 函数,这个函数会返回一个条件变量 Condition ,这个条件变量有三个主要的函数 await 、 signal 和 signalAll ,这三个函数的作用和效果跟 Object 类的 wait 、 notify 和 notifyAll 一样,在阅读下文之前,大家首先需要了解他们的用法。

  • 哪个线程调用函数 condition.await ,那个线程就会被挂起。
  • 如果线程调用函数 conditon.signal ,则会唤醒一个被 condition.await 函数阻塞的线程。
  • 如果线程调用函数 conditon.signalAll ,则会唤醒所有被 condition.await 函数阻塞的线程。

CountDownLatch

我们在使用 CountDownLatch 时,会有线程调用 CountDownLatch 的 await 函数,其他线程会调用 CountDownLatch 的 countDown 函数。在 CountDownLatch 内部会有一个计数器,计数器的值我们在初始化的时候可以进行设置,线程每调用一次 countDown 函数计数器的值就会减一。

  • 如果在线程在调用 await 函数之前,计数器的值已经小于或等于0时,调用 await 函数的线程就不会阻塞,直接放行。
  • 如果在线程在调用 await 函数之前,计数器的值大于0时,调用 await 函数的线程就会被阻塞,当有其他线程将计数器的值降低为0时,那么这个将计数器降低为0线程就需要使用 condition.signalAll() 函数将其他所有被 await 阻塞的函数唤醒。
  • 线程如果想阻塞自己的话可以使用函数 condition.await() ,如果某个线程在进入临界区之后达到了唤醒其他线程的条件,我们则可以使用函数 condition.signalAll() 唤醒所有被函数 await 阻塞的线程。

上面的规则已经将 CountDownLatch 的整体功能描述清楚了,为了能够将代码解释清楚,我将对应的文字解释放在了代码当中:

import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

public class MyCountDownLatch {
    private ReentrantLock lock = new ReentrantLock();
    private Condition condition = lock.newCondition();
    private int curValue;

    public MyCountDownLatch(int targetValue) {
        // 我们需要有一个变量去保存计数器的值
        this.curValue = targetValue;
    }

    public void countDown() {
        // curValue 是一个共享变量
        // 我们需要用锁保护起来
        // 因此每次只有一个线程进入 lock 保护
        // 的代码区域
        lock.lock();
        try {
            // 每次执行 countDown 计数器都需要减一
            // 而且如果计数器等于0我们需要唤醒哪些被
            // await 函数阻塞的线程
            curValue--;
            if (curValue <= 0)
                condition.signalAll();
        }catch (Exception ignored){}
        finally {
            lock.unlock();
        }
    }

    public void await() {
        lock.lock();
        try {
            // 如果 curValue 的值大于0
            // 则说明 countDown 调用次数还不够
            // 需要将线程挂起 否则直接放行
            if (curValue > 0)
                // 使用条件变量 condition 将线程挂起
                condition.await();
        }catch (Exception ignored){}
        finally {
            lock.unlock();
        }
    }
}

可以使用下面的代码测试我们自己写的 CountDownLatch :

public static void main(String[] args) throws InterruptedException {
    MyCountDownLatch latch = new MyCountDownLatch(5);
    for (int i = 0; i < 3; i++) {
        new Thread(() -> {
            latch.countDown();
            System.out.println(Thread.currentThread().getName() + "countDown执行完成");
        }).start();
    }

    for (int i = 0; i < 10; i++) {
        new Thread(() -> {
            try {
                TimeUnit.SECONDS.sleep(2);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            try {
                latch.await();
            } catch (Exception e) {
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName() +  "latch执行完成");
        }).start();
    }
}

CyclicBarrier

CyclicBarrier 有一个路障(同步点),所有的线程到达路障之后都会被阻塞,当被阻塞的线程个数达到指定的数目的时候,就需要对指定数目的线程进行放行。

  • 在 CyclicBarrier 当中会有一个数据 threadCount ,表示在路障需要达到这个 threadCount 个线程的时候才进行放行,而且需要放行 threadCount 个线程,这里我们可以循环使用函数 condition.signal() 去唤醒指定个数的线程,从而将他们放行。如果线程需要将自己阻塞住,可以使用函数 condition.await() 。
  • 在 CyclicBarrier 当中需要有一个变量 currentThreadNumber ,用于记录当前被阻塞的线程的个数。
  • 用户还可以给 CyclicBarrier 传入一个 Runnable 对象,当放行的时候需要执行这个 Runnable 对象,你可以新开一个线程去执行这个 Runnable 对象,或者让唤醒其他线程的这个线程执行 Runnable 对象。

根据上面的 CyclicBarrier 要求,写出的代码如下(分析和解释在注释当中):

import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

public class MyCyclicBarrier {

    private ReentrantLock lock = new ReentrantLock();
    private Condition condition = lock.newCondition();
    private int threadCount;
    private int currentThreadNumber;
    private Runnable runnable;

    public MyBarrier(int count) {
        threadCount = count;
    }

    /**
     * 允许传入一个 runnable 对象
     * 当放行一批线程的时候就执行这个 runnable 函数
     * @param count
     * @param runnable
     */
    public MyBarrier(int count, Runnable runnable) {
        this(count);
        this.runnable = runnable;
    }
    
    public void await() {
        lock.lock();
        currentThreadNumber++;
        try {
            // 如果阻塞的线程数量不到 threadCount 需要进行阻塞
            // 如果到了需要由这个线程唤醒其他线程
            if (currentThreadNumber == threadCount) {
                // 放行之后需要重新进行计数
                // 因为放行之后 condition.await();
                // 阻塞的线程个数为 0
                currentThreadNumber = 0;
                if (runnable != null) {
                    new Thread(runnable).start();
                }
                // 唤醒 threadCount - 1 个线程 因为当前这个线程
                // 已经是在运行的状态 所以只需要唤醒 threadCount - 1
                // 个被阻塞的线程
                for (int i = 1; i < threadCount; i++)
                    condition.signal();
            }else {
                // 如果数目还没有达到则需要阻塞线程
                condition.await();
            }
        }catch (Exception ignored){}
        finally {
            lock.unlock();
        }
    }

}

下面是测试我们自己写的 路障 的代码:

public static void main(String[] args) throws InterruptedException {
    MyCyclicBarrier barrier = new MyCyclicBarrier(5, () -> {
        System.out.println(Thread.currentThread().getName() + "开启一个新线程");
        for (int i = 0; i < 1; i++) {
            System.out.println(i);
        }
    });

    for (int i = 0; i < 5; i++) {
        new Thread(() -> {
            System.out.println(Thread.currentThread().getName() + "进入阻塞");
            try {
                TimeUnit.SECONDS.sleep(2);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            barrier.await();
            System.out.println(Thread.currentThread().getName() + "阻塞完成");
        }).start();
    }
}

Semaphore

Semaphore 可以控制执行某一段临界区代码的线程数量,在 Semaphore 当中会有两个计数器 semCount 和 curCount 。

semCount
curCount

这个工具实现起来也并不复杂,具体分析都在注释当中:

import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

public class MySemaphore {

    private ReentrantLock lock = new ReentrantLock();
    private Condition condition = lock.newCondition();
    private int semCount;
    private int curCount;

    public MySemaphore(int semCount) {
        this.semCount = semCount;
    }

    public void acquire() {
        lock.lock();
        try {
            // 正在执行临界区代码的线程个数加一
            curCount++;
            // 如果线程个数大于指定的能够执行的线程个数
            // 需要将当前这个线程阻塞起来
            // 否则直接放行
            if (curCount > semCount) {
                condition.await();
            }
        }catch (Exception ignored) {}
        finally {
            lock.unlock();
        }
    }

    public void release() {
        lock.lock();
        try {
            // 线程执行完临界区的代码
            // 将要离开临界区 因此 curCount 
            // 需要减一
            curCount--;
            // 如果有线程阻塞需要唤醒被阻塞的线程
            // 如果没有被阻塞的线程 这个函数执行之后
            // 对结果也不会产生影响 因此在这里不需要进行
            // if 判断
            condition.signal();
            // signal函数只对在调用signal函数之前
            // 被await函数阻塞的线程产生影响 如果
            // 某个线程调用 await 函数在 signal 函数
            // 执行之后,那么前面那次 signal 函数调用
            // 不会影响后面这次 await 函数
        }catch (Exception ignored){}
        finally {
            lock.unlock();
        }
    }
}

使用下面的代码测试我们自己写的 MySemaphore :

public static void main(String[] args) {
    MySemaphore mySemaphore = new MySemaphore(5);
    for (int i = 0; i < 10; i++) {
        new Thread(() -> {
            mySemaphore.acquire();
            System.out.println(Thread.currentThread().getName() + "已经进入临界区");
            try {
                TimeUnit.SECONDS.sleep(2);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            mySemaphore.release();
            System.out.println(Thread.currentThread().getName() + "已经离开临界区");
        }).start();
    }
}

总结

在本文当中主要给大家介绍了三个在并发当中常用的工具类该如何使用,然后介绍了我们自己实现三个工具类的细节,其实主要是利用 条件变量 实现的,因为它可以实现线程的阻塞和唤醒,其实只要大家了解 条件变量 的使用方法,和三种工具的需求大家也可以自己实现一遍。

原文链接:
www.cnblogs.com/Chang-LeHun…