JUC. CyclicBarrier

43 阅读2分钟

# JUC. CyclicBarrier

先运行一个示例:

/**
 * Demonstrates {@link CyclicBarrier}: {@code N} worker threads each bump a shared counter, wait at
 * the barrier, and continue only once all of them have arrived. The barrier action runs once, in
 * the last thread to arrive, before any waiting worker is released.
 */
public class CyclicBarrierDemo {
    /** Number of worker threads; also the number of parties the barrier waits for. */
    private static final int N = 10;

    /**
     * Arrival counter shared by all workers. The reference never changes, so it is {@code final};
     * {@code AtomicInteger} already provides the needed visibility, so {@code volatile} is redundant.
     */
    private static final AtomicInteger current = new AtomicInteger(0);

    public static void main(String[] args) {
        CyclicBarrier cyclicBarrier = new CyclicBarrier(N, () -> {
            System.out.println("我们都到了");
        });

        Thread[] workers = new Thread[N];
        for (int i = 0; i < N; i++) {
            workers[i] = new Thread(() -> {
                System.out.println("current:" + current.incrementAndGet());

                try {
                    cyclicBarrier.await();
                } catch (InterruptedException e) {
                    // Restore the interrupt flag before bailing out.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                } catch (BrokenBarrierException e) {
                    throw new RuntimeException(e);
                }

                // Every party incremented before the barrier tripped, so this always prints N.
                System.out.println("wait:" + current.get());
            });
            workers[i].start();
        }

        // Wait for the workers deterministically instead of sleeping a fixed time and hoping
        // they have all finished by then.
        try {
            for (Thread worker : workers) {
                worker.join();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        System.out.println("End.");
    }
}

输出(某一次运行的结果;各 current 行的先后顺序每次可能不同)

current:1
current:2
current:3
current:4
current:5
current:6
current:7
current:8
current:9
current:10
我们都到了
wait:10
wait:10
wait:10
wait:10
wait:10
wait:10
wait:10
wait:10
wait:10
wait:10
End.

核心用法

  1. 初始化循环屏障:指定参与线程数 N,以及全部线程到达后执行的屏障动作
CyclicBarrier cyclicBarrier = new CyclicBarrier(N, () -> {
    System.out.println("我们都到了");
});
  2. 在各线程中调用 await() 阻塞,直到 N 个线程全部到达
cyclicBarrier.await();

核心实现

初始化

/**
 * Creates a barrier that trips once {@code parties} threads are waiting on it. When it trips,
 * {@code barrierAction} (if non-null) is run by the last thread to arrive.
 *
 * @param parties number of threads that must call await before the barrier trips
 * @param barrierAction command to run when the barrier trips, or {@code null} for none
 * @throws IllegalArgumentException if {@code parties} is less than 1
 */
public CyclicBarrier(int parties, Runnable barrierAction) {
    if (parties < 1) {
        throw new IllegalArgumentException();
    }
    this.barrierCommand = barrierAction;
    this.parties = parties;
    // The countdown starts full; each arriving thread decrements it in dowait.
    this.count = parties;
}

阻塞等待

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}

   

private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    // 互斥锁
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        // 中断器
        final Generation g = generation;
        if (g.broken)
            throw new BrokenBarrierException();
        // 线程被中断
        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }
        // 减计数
        int index = —count;
        // 计数0,全部到达
        if (index == 0) {  // tripped
            boolean ranAction = false;
            try {
                // 执行预置命令
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;

                // 更新屏蔽,并唤醒其他阻塞线程
                nextGeneration();
                return 0;
            } finally {
                if (!ranAction)
                    breakBarrier();
            }
        }

        // 循环等待,完成、阻塞、中断、超时
        for (;;) {
            try {
                if (!timed)
                    trip.await();
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    // We're about to finish waiting even if we had not
                    // been interrupted, so this interrupt is deemed to
                    // "belong" to subsequent execution.
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken)
                throw new BrokenBarrierException();

            if (g != generation)
                return index;

            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }

}

private static class Generation {
    boolean broken = false;
}

private void breakBarrier() {
    generation.broken = true;
    count = parties;
    trip.signalAll();
}

更新屏障(开启下一代),并唤醒其他阻塞线程

// Guards the barrier state (count, generation); dowait acquires it on entry.
private final ReentrantLock lock = new ReentrantLock();
// Condition that waiting parties block on until the barrier trips or is broken.
private final Condition trip = lock.newCondition();

private void nextGeneration() {
    // signal completion of last generation
    trip.signalAll();
    // set up next generation
    count = parties;
    generation = new Generation();
}

public final void signalAll() {
    // 是否独占
    if (!isHeldExclusively())
        throw new IllegalMonitorStateException();
    Node first = firstWaiter;
    if (first != null)
        doSignalAll(first);
}

// 是否独占
protected final boolean isHeldExclusively() {
    return getExclusiveOwnerThread() == Thread.currentThread();
}

/**
 * Transfers every condition waiter, starting at {@code first}, to the lock's sync queue.
 *
 * @param first the first waiter; the caller (signalAll) only passes a non-null node
 */
private void doSignalAll(Node first) {
    // Detach the whole condition list up front; it is empty afterwards.
    lastWaiter = firstWaiter = null;
    do {
        Node next = first.nextWaiter;
        // Unlink the node from the condition list before handing it over.
        first.nextWaiter = null;
        // Returns false for a node no longer in CONDITION state; the result is ignored here.
        transferForSignal(first);
        first = next;
    } while (first != null);
}

/**
 * Moves a node from the condition list to the lock's sync queue.
 *
 * @param node the condition waiter to transfer
 * @return {@code true} if transferred; {@code false} if the node was no longer in CONDITION
 *         state (presumably cancelled)
 */
final boolean transferForSignal(Node node) {
    // Flip the wait status from CONDITION to 0; failure means the node is no longer waiting here.
    if (!compareAndSetWaitStatus(node, Node.CONDITION, 0))
        return false;
    // Append the node to the tail of the sync queue; enq returns the node's predecessor.
    Node p = enq(node);
    int ws = p.waitStatus;
    // Wake the transferred node's OWN thread (not the predecessor) when the predecessor has
    // waitStatus > 0 (presumably cancelled) or cannot be marked SIGNAL. Otherwise the
    // predecessor is expected to wake it later via release/unparkSuccessor.
    if (ws > 0 || !compareAndSetWaitStatus(p, ws, Node.SIGNAL))
        LockSupport.unpark(node.thread);
    return true;
}

条件等待:trip.await() 的实现(AQS ConditionObject.await)

/**
 * Condition wait (what {@code trip.await()} runs). It releases the lock completely, parks until
 * the node is moved to the sync queue or an interrupt is detected, then reacquires the lock with
 * the saved state before returning.
 *
 * @throws InterruptedException if the thread is interrupted on entry, or when
 *         {@code reportInterruptAfterWait} reports a recorded interrupt that way
 */
public final void await() throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    // Append a node for the current thread to this condition's list (lock still held here).
    Node node = addConditionWaiter();
    // Fully release the lock, keeping the lock state so it can be restored exactly.
    int savedState = fullyRelease(node);
    int interruptMode = 0;
    // Loop until the node is on the lock's sync queue.
    while (!isOnSyncQueue(node)) {
        // Not transferred yet: block the current thread.
        LockSupport.park(this);
        // Stop waiting once checkInterruptWhileWaiting reports an interrupt.
        if ((interruptMode = checkInterruptWhileWaiting(node)) != 0)
            break;
    }
    // Reacquire the lock with the saved state, queueing as needed. A true result is recorded
    // as REINTERRUPT unless THROW_IE is already pending.
    if (acquireQueued(node, savedState) && interruptMode != THROW_IE)
        interruptMode = REINTERRUPT;
    if (node.nextWaiter != null) // clean up if cancelled
        unlinkCancelledWaiters();
    // Report any recorded interrupt according to interruptMode.
    if (interruptMode != 0)
        reportInterruptAfterWait(interruptMode);
}

/**
 * Appends a new CONDITION node for the current thread to the tail of this condition's list and
 * returns it. Called from await before the lock is released, so plain field writes are used.
 *
 * @return the newly added node
 */
private Node addConditionWaiter() {
    Node t = lastWaiter;
    // If lastWaiter is cancelled, clean out.
    if (t != null && t.waitStatus != Node.CONDITION) {
        unlinkCancelledWaiters();
        t = lastWaiter;
    }
    Node node = new Node(Thread.currentThread(), Node.CONDITION);
    // Empty list: the new node is also the first waiter.
    if (t == null)
        firstWaiter = node;
    else
        t.nextWaiter = node;
    lastWaiter = node;
    return node;
}

/**
 * Removes, in a single pass, every node whose waitStatus is no longer CONDITION from this
 * condition's waiter list, fixing up firstWaiter and lastWaiter as needed.
 */
private void unlinkCancelledWaiters() {
    Node t = firstWaiter;
    // Last node kept so far (null while none has been kept).
    Node trail = null;
    while (t != null) {
        Node next = t.nextWaiter;
        if (t.waitStatus != Node.CONDITION) {
            // Drop t: clear its link and splice its successor in after trail.
            t.nextWaiter = null;
            if (trail == null)
                firstWaiter = next;
            else
                trail.nextWaiter = next;
            // Dropped the tail: the last kept node becomes the new tail.
            if (next == null)
                lastWaiter = trail;
        }
        else
            trail = t;
        t = next;
    }
}

/**
 * Returns whether {@code node} has been moved onto the lock's sync queue.
 *
 * @param node a node originally created by addConditionWaiter
 */
final boolean isOnSyncQueue(Node node) {
    // Still in CONDITION state, or no predecessor link yet: not on the sync queue.
    if (node.waitStatus == Node.CONDITION || node.prev == null)
        return false;
    if (node.next != null) // If has successor, it must be on queue
        return true;
    // A non-null prev alone is not taken as proof; fall back to searching from the tail.
    return findNodeFromTail(node);
}

释放锁(fullyRelease:一次性释放当前线程持有的全部状态/重入计数)

/**
 * Releases the lock completely (the whole current state in one release call) and returns the
 * released state so it can be reacquired later.
 *
 * @param node the caller's condition node; marked CANCELLED if the release fails
 * @return the state value that was released
 * @throws IllegalMonitorStateException if release returns false, or if tryRelease throws it
 *         because the caller is not the owner
 */
final int fullyRelease(Node node) {
    boolean failed = true;
    try {
        int savedState = getState();
        if (release(savedState)) {
            failed = false;
            return savedState;
        } else {
            throw new IllegalMonitorStateException();
        }
    } finally {
        // Any failure path (including exceptions from release) cancels the waiter node.
        if (failed)
            node.waitStatus = Node.CANCELLED;
    }
}

public final boolean release(int arg) {
    if (tryRelease(arg)) {
        Node h = head;
        if (h != null && h.waitStatus != 0)
            unparkSuccessor(h);
        return true;
    }
    return false;
}

protected final boolean tryRelease(int releases) {
    int c = getState() - releases;
    if (Thread.currentThread() != getExclusiveOwnerThread())
        throw new IllegalMonitorStateException();
    boolean free = false;
    if (c == 0) {
        free = true;
        setExclusiveOwnerThread(null);
    }
    setState(c);
    return free;
}

/**
 * Wakes the thread of {@code node}'s successor in the sync queue, if a suitable one exists.
 *
 * @param node the node whose successor should be woken (release passes the head)
 */
private void unparkSuccessor(Node node) {
    // Reset a negative waitStatus on node to 0; a failed CAS is tolerated (result ignored).
    int ws = node.waitStatus;
    if (ws < 0)
        compareAndSetWaitStatus(node, ws, 0);

    // Normally the successor is node.next. If that is null or has waitStatus > 0 (presumably
    // cancelled), walk backward from tail to node and keep the node with waitStatus <= 0 that
    // is closest to node.
    Node s = node.next;
    if (s == null || s.waitStatus > 0) {
        s = null;
        for (Node t = tail; t != null && t != node; t = t.prev)
            if (t.waitStatus <= 0)
                s = t;
    }
    if (s != null)
        LockSupport.unpark(s.thread);
}