基于 AQS 分析 CyclicBarrier

606 阅读6分钟

基于 AQS 分析 CyclicBarrier

纸上得来终觉浅,绝知此事要躬行 —— 陆游 「这是我参与2022首次更文挑战的第7天,活动详情查看:2022首次更文挑战

代码案例

public class CyclicBarrierDemo {
    static class TaskThread extends Thread {
        CyclicBarrier barrier;
        public TaskThread(CyclicBarrier barrier) {
            this.barrier = barrier;
        }

        @Override
        public void run() {
            try {
                Thread.sleep(1000);
                System.out.println(getName() + " 到达栅栏 A");
                barrier.await();
                System.out.println(getName() + " 继续执行");
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    public static void main(String[] args) {
        int threadNum = 5;
        CyclicBarrier barrier = new CyclicBarrier(threadNum, () -> System.out.println(Thread.currentThread().getName() + " 完成最后任务"));

        for (int i = 0; i < threadNum; i++) {
            new TaskThread(barrier).start();
        }
    }
}

结果输出

image.png

代码解释

就是开启5个线程等待这个五个线程都到执行到了await之后再执行传入进去的线程执行的任务,就相当于全班级去旅游,只有等最后一个人到了车上,等他关了车门才能发车,之后这些人再去完成旅游的这个任务

源码分析

先看一下构造函数是怎么样的

public CyclicBarrier(int parties, Runnable barrierAction) {
    // 传入进去的这个数量也就是 parties,也就是线程数量必须是大于零的
    if (parties <= 0) throw new IllegalArgumentException();
    // 其实就是定义要达到的线程执行await方法的数量才能执行
    this.parties = parties;
    // 等待线程执行await方法的数量,每一次都是从定义的减少到0,都到了之后可以进行重置
    this.count = parties;
    // 到达指定线程调用await方法的数量之后需要执行的方法
    this.barrierCommand = barrierAction;
}

我们这里定义的 parties 是5,完成后输出 System.out.println(Thread.currentThread().getName() + " 完成最后任务"

await() 方法分析

看看这个await方法究竟做了什么

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}

重要的是这个dowait方法,默认传入了一个 false ,一个 0,看看这里究竟做了什么,这个方法有点长 我们慢慢看

private int dowait(boolean timed, long nanos)throws InterruptedException, BrokenBarrierException,TimeoutException {
    // 定义了一个ReentrantLock锁
    final ReentrantLock lock = this.lock;
    // 先上个锁
    lock.lock();
    try {
        // 获取当前一代,就是当前循环的障碍,并且里面定义了一个属性来判断是栅栏是否被打破
        final Generation g = generation;
        // 默认是false 目前没有被打破
        if (g.broken)
            throw new BrokenBarrierException();

        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }
        // count - 1需要等待线程执行await方法的数量减少1 
        int index = --count; // 5-1
        // 此时index等于4
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // loop until tripped, broken, interrupted, or timed out
        for (;;) {
            // 这边一直在循环着,除非被打破,中断、或者超时才退出
            try {
                // timed 传的是false,此逻辑成立
                if (!timed)
                    trip.await();
                // nanos == 0
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();
			// g == generation
            if (g != generation)
                return index;
            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

timed 传的是false,此if (!timed)逻辑成立,此时走这个方法  trip.await();看看这个trip是什么东西

private final Condition trip = lock.newCondition();// 创建了一个Condition

在这里创建了一个Condition,我们先猜猜看,看看是不是将当前线程放入到了 Condition 的条件等待队列中去了

public final void await() throws InterruptedException {
    if (Thread.interrupted()) throw new InterruptedException();
    // 看下面的这个方法应该是将当前的线程放入了等待队列中去了
    Node node = addConditionWaiter();
    int savedState = fullyRelease(node);
    int interruptMode = 0;
    while (!isOnSyncQueue(node)) {
        LockSupport.park(this);
        if ((interruptMode = checkInterruptWhileWaiting(node)) != 0)
            break;
    }
    if (acquireQueued(node, savedState) && interruptMode != THROW_IE)
        interruptMode = REINTERRUPT;
    if (node.nextWaiter != null) // clean up if cancelled
        unlinkCancelledWaiters();
    if (interruptMode != 0)
        reportInterruptAfterWait(interruptMode);
}

看看这个addCondititonWaiter 这个方法是不是将当前的这个等待者放入到了条件队列中去了

private Node addConditionWaiter() {
    // 第一次执行这个方法的时候,定义了一个 t 指针指向了条件队列的最后一个等待者,不过此时lastWaiter是null的
    Node t = lastWaiter;
    if (t != null && t.waitStatus != Node.CONDITION) {
        unlinkCancelledWaiters();
        t = lastWaiter;
    }
    // 第一次来这里所以走到这里创建一个Node节点
    Node node = new Node(Thread.currentThread(), Node.CONDITION);
    // 条件队列的第一个指针也指向了这个node节点
    if (t == null)
        firstWaiter = node;
    else
        t.nextWaiter = node;
    // 尾部指针也指向了这个node节点
    lastWaiter = node;
    // 返回新创建的节点
    return node;
}

果真如此之前好像也分析过这个方法在 Condition 中,画一个图看看条件队列中的样子
image.png
之后走到了这个方法中fullyRelease(node),感觉像是释放锁的方法,我们进去看看

final int fullyRelease(Node node) {
    boolean failed = true;
    try {
        // 获取当前的state的值,此时的值是 1,这个1 是在加锁的时候设置的
        int savedState = getState();
        if (release(savedState)) {
            failed = false;
            return savedState;
        } else {
            throw new IllegalMonitorStateException();
        }
    } finally {
        if (failed)
            node.waitStatus = Node.CANCELLED;
    }
}

看这个方法release(savedState),好像是将state的值减少1

    public final boolean release(int arg) {
        if (tryRelease(arg)) {
            Node h = head;
            if (h != null && h.waitStatus != 0)
                unparkSuccessor(h);
            return true;
        }
        return false;
    }

先走的是tryRelease方法,应该是在该方法中减少的值

protected final boolean tryRelease(int releases) {
    // 1 - 1 =0
    int c = getState() - releases;
    if (Thread.currentThread() != getExclusiveOwnerThread())
        throw new IllegalMonitorStateException();
    boolean free = false;
    if (c == 0) {
        free = true;
        setExclusiveOwnerThread(null);
    }
    setState(c);
    return free;
}

释放成功了 ,独占线程也设置成null了,state也设置成0了,并且返回了true,往下走,看定义了一个h指针指向了head,不过此时head还是空值,所以直接返回了true,所以此时fullyRelease(Node node)这个方法返回了1,接着往下看,此时走到了这个while的逻辑

while (!isOnSyncQueue(node)) {
    LockSupport.park(this);
    if ((interruptMode = checkInterruptWhileWaiting(node)) != 0)
        break;
}

看一下这个方法 !isOnSyncQueue(node))

final boolean isOnSyncQueue(Node node) {
    // 当前节点的 waitStatus 值确实是 Node.CONDITION,所以返回了false
    if (node.waitStatus == Node.CONDITION || node.prev == null)
        return false;
    if (node.next != null)
        return true;
    return findNodeFromTail(node);
}

当前节点的 waitStatus 值确实是 Node.CONDITION,所以返回了false,那么 !isOnSyncQueue(node) 就是 true了,通过LockSupport.park(this)将当前线程挂起,注意此时将当前线程挂起到了Condition的条件队列中去了,这样陆陆续续的调用了5次await的话那么此时条件等待队列中的样子如下图所是
image.png
继续往下分析,当达到五个线程调用了这个await方法之后看往下该怎么进行,核心代码如下,我只截取了一部分,不过前面是有的

int index = --count; // 此时index等于0
if (index == 0) {  // tripped
    boolean ranAction = false;
    try {
        final Runnable command = barrierCommand;
        if (command != null)
            // 运行我们传进来的任务方法
            command.run();
        ranAction = true;
        // 看看这块做什么了
        nextGeneration();
        return 0;
    } finally {
        if (!ranAction)
            breakBarrier();
    }
}

nextGeneration()里面的代码实现,感觉就是唤醒我们之前被阻塞到条件等待队列中的线程

private void nextGeneration() {  
    trip.signalAll();
    // 重置了一下需要等待调用await方法次数
    count = parties;
    // 开启下一代
    generation = new Generation();
}

分析一下这个 signalAll 方法

public final void signalAll() {
    if (!isHeldExclusively())
        throw new IllegalMonitorStateException();
    Node first = firstWaiter;
    if (first != null)
        doSignalAll(first);
}

定义了一个first指针指向了条件等待队列的第一个等待节点
image.png
因为first不为空所以走到了doSignalAll(first)方法中,我们进入这里看看是不是将当前队列中的节点加入到AQS的队列中去

private void doSignalAll(Node first) {
    // 将条件对列的两个头尾指针都置为空
    lastWaiter = firstWaiter = null;
    do {
        // 获取下一个节点
        Node next = first.nextWaiter;
        // 指针变换将first的下一个等待者指向的指针设置为null
        first.nextWaiter = null;
        // 这个方法应该是加入到AQS队列中我们看看
        transferForSignal(first);
        first = next;
    } while (first != null);
}

第一次是这个样子的,头节点出队,并且指针进行了变换
image.png
之后first指针当作游标将对了中的所有节点都出队,通过transferForSignal 下面调用的 enq 方法加入到AQS中去,此时条件队列中的节点为空了,啥都没有了,分析一下这个transferForSignal 方法

final boolean transferForSignal(Node node) {
    if (!compareAndSetWaitStatus(node, Node.CONDITION, 0))
        return false;
    // 其实这个就是入队
    Node p = enq(node);
    int ws = p.waitStatus;
    if (ws > 0 || !compareAndSetWaitStatus(p, ws, Node.SIGNAL))
        LockSupport.unpark(node.thread);
    return true;
}

看一下这个 enq(node) 方法我们分析好多次了

private Node enq(final Node node) {
    for (;;) {
        // 定义了一个t指向了tail 此时AQS中的tail和head都是null
        Node t = tail;
        if (t == null) { // Must initialize
            // 所以初始化创建了一个空的Node,并且将tail和head都指向了当前的Node节点
            if (compareAndSetHead(new Node()))
                tail = head;
        } else {
            // 后续就是改变指针移动tail指针了
            node.prev = t;
            if (compareAndSetTail(t, node)) {
                t.next = node;
                return t;
            }
        }
    }
}

image.png
返回了头节点,此时 if (ws > 0 || !compareAndSetWaitStatus(p, ws, Node.SIGNAL)) 这里面的逻辑都是false 因为头节点的waitStatus 值是0,后面的设置成功将头节点设置成了-1 ,这样一来所有的节点都放入到了AQS的等待队列中去了,并且将对列中的前面的节点的waitStaus都改成了-1,也就是Signal
image.png
之后走到了最后的finally 方法中释放锁的时候将这些节点一一唤醒去继续获取锁并且执行,我们继续看一下

lock.unlock();

释放锁我们也分析过好多遍了,再来看看

public void unlock() {
    sync.release(1);
}

里面调用的是sync的release方法,再进来看看

public final boolean release(int arg) {
    if (tryRelease(arg)) {
        Node h = head;
        if (h != null && h.waitStatus != 0)
            unparkSuccessor(h);
        return true;
    }
    return false;
}

继续看这个tryRelease(arg) 方法

protected final boolean tryRelease(int releases) {
    // 此时state是1,1-1 = 0
    int c = getState() - releases;
    if (Thread.currentThread() != getExclusiveOwnerThread())
        throw new IllegalMonitorStateException();
    boolean free = false;
    // 逻辑成立
    if (c == 0) {
        free = true;
        // 设置独占线程为null
        setExclusiveOwnerThread(null);
    }
    // 设置state = 0
    setState(c);
    return free;
}

所以此时这个tryRelease(arg) 方法返回的是true,此时定义了一个h指针指向了,AQS队列中的头节点head指针
image.png
此时h【头节点】不是null并且头节点的waitStatus是-1所以进行后续节点的唤醒,看看后续的唤醒方法

private void unparkSuccessor(Node node) { // 传进来的是头节点
    // 头节点等于-1
    int ws = node.waitStatus;
    if (ws < 0)
        // 将头节点的-1 改成了0
        compareAndSetWaitStatus(node, ws, 0);
    // 定义了一个s指向了node节点的下一个节点
    Node s = node.next;
    if (s == null || s.waitStatus > 0) {
        s = null;
        for (Node t = tail; t != null && t != node; t = t.prev)
            if (t.waitStatus <= 0)
                s = t;
    }
    // 此时s不为空所以被唤醒
    if (s != null)
        LockSupport.unpark(s.thread);
}

image.png
那么唤醒之后的逻辑是啥呢,看看在哪里被挂起的,看看之前这些线程都是哪里挂起的
image.png
此时被唤醒后去获取锁走的是acquireQueued方法

final boolean acquireQueued(final Node node, int arg) {
    boolean failed = true;
    try {
        boolean interrupted = false;
        for (;;) {
            // 当前节点的前驱节点是头节点,所以再次尝试获取锁
            final Node p = node.predecessor();
            if (p == head && tryAcquire(arg)) {
                setHead(node);
                p.next = null; // help GC
                failed = false;
                return interrupted;
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                interrupted = true;
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

再次尝试获取锁

protected final boolean tryAcquire(int acquires) {
    return nonfairTryAcquire(acquires);
}

走的是ReentrantLock的非公平获取锁的方式

final boolean nonfairTryAcquire(int acquires) {  // 1
    // 获取当前线程
    final Thread current = Thread.currentThread();
    // 此时state 的值为0
    int c = getState();
    if (c == 0) {
        // 将state的值设置为1
        if (compareAndSetState(0, acquires)) {
            // 设置独占线程,并且返回true
            setExclusiveOwnerThread(current);
            return true;
        }
    }
    else if (current == getExclusiveOwnerThread()) {
        int nextc = c + acquires;
        if (nextc < 0) // overflow
            throw new Error("Maximum lock count exceeded");
        setState(nextc);
        return true;
    }
    return false;
}

那么此时就走下面得逻辑了,设置state值后续的逻辑

if (p == head && tryAcquire(arg)) {
    setHead(node);
    p.next = null; // help GC
    failed = false;
    return interrupted;
}

设置头节点通过setHead方法

private void setHead(Node node) {
    head = node;
    node.thread = null;
    node.prev = null;
}

image.png
此时第一个AQS队列中非空节点就可以出队列进行任务的继续执行了