【JUC】CyclicBarrier实现和原理

612 阅读2分钟

一、概述

CyclicBarrier可循环使用的屏障,主要功能是拦截一组线程,直到最后一个线程到达屏障,屏障才会放行;

二、常用方法

1、int await()

在指定的数量的线程未到达时,所有线程都会在此阻塞

2、int await(long timeout, TimeUnit unit)

只有在指定数量的线程到达时,或者等待时间超过timeout才会放行

3、boolean isBroken()

查询屏障是否损坏

4、void reset()

重置为初始状态,所有等待线程抛出BrokenBarrierException异常

5、int getNumberWaiting()

返回正在等待的线程数

5、int getParties()

返回需要到达屏障的线程数

三、源码解析

主要是基于ReentrantLock和Condition实现

1、构造函数

    public CyclicBarrier(int parties) {
        this(parties, null);
    }
    public CyclicBarrier(int parties, Runnable barrierAction) {
        // parties需要的线程参与数
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        // 计数器,当count为0,则表示可以放行
        this.count = parties;
        // 屏障启动调用的自定义函数
        this.barrierCommand = barrierAction;
    }

2、await

    public int await() throws InterruptedException, BrokenBarrierException {
        try {
            return dowait(false, 0L);
        } catch (TimeoutException toe) {
            throw new Error(toe); // cannot happen
        }
    }
    private int dowait(boolean timed, long nanos)
        throws InterruptedException, BrokenBarrierException,
               TimeoutException {
        final ReentrantLock lock = this.lock;
        // 加锁,保证线程安全
        lock.lock();
        try {
            // 静态内部类,持有屏障是否损坏
            final Generation g = generation;

            if (g.broken)
                // 如果损坏则抛出异常
                throw new BrokenBarrierException();

            if (Thread.interrupted()) {
                // 线程中断则将屏障设置为损坏,count重置为parties,唤醒所有等待线程,所有线程都将抛出异常
                breakBarrier();
                throw new InterruptedException();
            }

            int index = --count;
            if (index == 0) {  // tripped
                // 计数器为0,表示所有线程都已到达
                boolean ranAction = false;
                try {
                    final Runnable command = barrierCommand;
                    if (command != null)
                        // 如果不为null则执行创建CyclicBarrier时定义的动作,一般没有
                        command.run();
                    ranAction = true;
                    // 唤醒所有等待线程并重置parties,换代,创建新的Generation
                    nextGeneration();
                    return 0;
                } finally {
                    // 如果自定义的操作报错则设置屏障损坏
                    if (!ranAction)
                        breakBarrier();
                }
            }

            // 这里表示计数器未归零,进入等待,自旋直到指定线程数到达、损坏、中断、超时
            for (;;) {
                try {
                    // 使用condition的await实现阻塞
                    if (!timed)
                        trip.await();
                    else if (nanos > 0L)
                        nanos = trip.awaitNanos(nanos);
                } catch (InterruptedException ie) {
                    if (g == generation && ! g.broken) {
                        breakBarrier();
                        throw ie;
                    } else {
                        // We're about to finish waiting even if we had not
                        // been interrupted, so this interrupt is deemed to
                        // "belong" to subsequent execution.
                        Thread.currentThread().interrupt();
                    }
                }
                // 唤醒后操作
                // 损坏则抛出异常
                if (g.broken)
                    throw new BrokenBarrierException();
                // 如果换代则返回index
                if (g != generation)
                    return index;

                if (timed && nanos <= 0L) {
                    // 超时异常,设置为损坏
                    breakBarrier();
                    throw new TimeoutException();
                }
            }
        } finally {
            lock.unlock();
        }
    }

使用ReentrantLock保证线程安全,使用Condition来实现阻塞和线程唤醒 Generation对象用于换代,来实现可循环使用

换代方法:重置count,换代

    private void nextGeneration() {
        // signal completion of last generation
        trip.signalAll();
        // set up next generation
        count = parties;
        generation = new Generation();
    }

损坏方法:重置count,设置broken为true

    private void breakBarrier() {
        generation.broken = true;
        count = parties;
        trip.signalAll();
    }

3、await(long timeout, TimeUnit unit)

带有超时时间的await,主体方法还是调用上面介绍的dowait方法

    public int await(long timeout, TimeUnit unit)
        throws InterruptedException,
               BrokenBarrierException,
               TimeoutException {
        return dowait(true, unit.toNanos(timeout));
    }

4、reset

reset方法很简单就是先设置屏障损坏,让所有等待线程唤醒并抛出BrokenBarrierException,再执行换代方法重置计数器

    public void reset() {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            breakBarrier();   // break the current generation
            nextGeneration(); // start a new generation
        } finally {
            lock.unlock();
        }
    }