Semaphore原理,CountDownLatch原理,CyclicBarrier原理

88 阅读2分钟

Semaphore原理

定义

得知:使用AQS作为模板类,然后使用其共享锁机制,实现了公平锁和非公平锁来完成Semaphore信号量语义。获取permit的acquire(int permits)操作、释放permit的release(int permits)操作,均是由sync类来完成的

public class Semaphore implements java.io.Serializable {
    private final Sync sync;
    static final class FairSync extends Sync {}
    static final class NonfairSync extends Sync {}
    abstract static class Sync extends AbstractQueuedSynchronizer {}
    
    public void acquire(int permits) throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireSharedInterruptibly(permits);
    }
    
    public void release(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        sync.releaseShared(permits);
    }
}

acquireSharedInterruptibly方法原理

该代码由AQS完成,我们不做过多赘述,主要看这里的Interruptibly的意思,此时相当于响应线程的中断。最后还是调用子类的tryAcquireShared模板方法,让子类实现自己获取信号量的机制。

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer{
    public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
    }
}

我们先来看公平锁的实现:

state变量用于标识permit值。

static final class FairSync extends Sync {
    // 构造器用于初始化父类AQS的state变量
    FairSync(int permits) {
        super(permits);
    }
​
    protected int tryAcquireShared(int acquires) {
        for (;;) {
            // 如果有线程在AQS的队列中排队,那么返回-1,将由AQS完成阻塞操作
            if (hasQueuedPredecessors())
                return -1;
            int available = getState();
            int remaining = available - acquires;
            // 如果此时有可用的多余的信号量,那么进行CAS操作,如果失败,那么返回剩下的资源数。如果此时CAS成功,那么返回的资源数就为当前值,有可能为0或者大于0
            if (remaining < 0 ||
                compareAndSetState(available, remaining))
                return remaining;
        }
    }
}

非公平锁的实现:

static final class NonfairSync extends Sync {
    NonfairSync(int permits) {
        super(permits);
    }
    
    // 由父类Sync来完成调用
    protected int tryAcquireShared(int acquires) {
        return nonfairTryAcquireShared(acquires);
    }
}
​
// sync类实现方法
final int nonfairTryAcquireShared(int acquires) {
    for (;;) {
        // 非公平锁直接CAS抢即可,直到可用资源数小于0
        int available = getState();
        int remaining = available - acquires;
        if (remaining < 0 ||
            compareAndSetState(available, remaining))
            return remaining;
    }
}

releaseShared方法原理

AQS实现方法,我们可以看到当模板方法tryReleaseShared,由子类完成释放后,那么将会调用doReleaseShared方法唤醒后面等待的线程。

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer{
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }
}

直接看Sync类的实现:

protected final boolean tryReleaseShared(int releases) {
    for (;;) {
        // 直接通过CAS操作对state变量+1即可
        int current = getState();
        int next = current + releases;
        if (next < current)
            throw new Error("Maximum permit count exceeded");
        if (compareAndSetState(current, next))
            return true;
    }
}

CountDownLatch原理

定义

是一个同步器,用于一个或者多个线程等待其他线程完成一组操作,原理如下:

1、AQS的state变量用于表示操作个数

2、AQS的共享锁机制完成唤醒

3、等待锁的线程使用acquireShared方法获取共享锁等待

4、操作线程使用releaseShared方法用于唤醒等待共享锁的线程

构造器原理

从构造器中得出,count值和核心操作均由内部同步器类Sync完成

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}

CountDown方法原理

public void countDown() {
    sync.releaseShared(1);
}

Await方法原理

public boolean await(long timeout, TimeUnit unit)
    throws InterruptedException {
    return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

tryAcquireSharedNanos方法原理

该方法由AQS实现,可以看到方法中抛出了InterruptedException中断异常,由此可见该方法响应了线程中断,但是核心操作还是由子类来实现tryAcquireShared(arg)来完成共享锁的获取操作。

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer{
    public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
        throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
        return tryAcquireShared(arg) >= 0 ||
            doAcquireSharedNanos(arg, nanosTimeout); // 该方法在后面说AQS时完成讲解
    }
}

tryAcquireShared方法实现

就是看变量是否为0,如果为0,那么无条件返回1,此时将会直接获取到共享锁。

protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

releaseShared方法

该方法由AQS来实现,可以看到通过子类完成tryReleaseShared方法释放共享锁,如果释放成功,那么直接调用doReleaseShared方法完成等待获取共享锁的线程,获取共享锁。

public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer{
    public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
    }
}

tryReleaseShared方法实现过程

protected boolean tryReleaseShared(int releases) {
    for (;;) {
        // 获取当前state值,代表了完成的操作个数
        int c = getState();
        if (c == 0)
            return false;
        // 计算更新值,CAS原子性的修改即可
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            // 若修改成功,那么判断当前线程是不是最后一个完成操作的线程,如果是,那么返回true,此时唤醒所有等待共享锁的线程
            return nextc == 0;
    }
}

CyclicBarrier原理

核心数据结构与构造器原理

public class CyclicBarrier {
    // 该类用于reset后复用该结构,每一次的party都会生成一个新的该类的实例
    private static class Generation {
        boolean broken = false; // 当前party有没有被强制中断
    }
    private final ReentrantLock lock = new ReentrantLock();
    // 用于阻塞线程的条件变量:有未到party的线程,那么等待在该条件变量上
    private final Condition trip = lock.newCondition();
    // 参与party的线程数
    private final int parties;
    // 当所有的线程都参与到了party中后回调的方法
    private final Runnable barrierCommand;
    // 当前party
    private Generation generation = new Generation();
    // 还未到party的线程数
    private int count;
    
    public CyclicBarrier(int parties, Runnable barrierAction) {
        if (parties <= 0) throw new IllegalArgumentException();
        this.parties = parties;
        this.count = parties;
        this.barrierCommand = barrierAction;
    }
}

await方法

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}
​
private int dowait(boolean timed, long nanos) throws InterruptedException, BrokenBarrierException,
TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        final Generation g = generation;  // 保存当前party时的Generation快照,更新后将不会影响这里的实例
        if (g.broken)
            throw new BrokenBarrierException();
        // 有中断的线程混入其中,干掉其他参会人重新开始,此时并没有改变party的Generation
        if (Thread.interrupted()) {
            breakBarrier();
            throw new InterruptedException();
        }
        int index = --count;
        if (index == 0) {  // 最后一个到达party的线程,负责唤醒所有阻塞在条件变量上的线程,然后回调barrierCommand(若正常完成,那么不需要手动调用reset,因为这里调用了nextGeneration)
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run();
                ranAction = true;
                nextGeneration(); // 进入下一个party
                return 0;
            } finally {
                // barrierCommand回调方法发生了异常,那么设置broken标志位
                if (!ranAction)
                    breakBarrier();
            }
        }
        // 循环等待最后一个参与party的线程唤醒自己,或者?被中断。或者?等待超时
        for (;;) {
            try {
                if (!timed)
                    trip.await();
                else if (nanos > 0L)
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
                if (g == generation && ! g.broken) {
                    breakBarrier();
                    throw ie;
                } else {
                    Thread.currentThread().interrupt();
                }
            }
            if (g.broken)
                throw new BrokenBarrierException();
            if (g != generation)
                return index;
            if (timed && nanos <= 0L) {
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock();
    }
}

reset方法

public void reset() {
    final ReentrantLock lock = this.lock;
    lock.lock();
    try {
        breakBarrier(); // 将所有参与party的线程唤醒
        nextGeneration();  // 生成下一代
    } finally {
        lock.unlock();
    }
}

breakBarrier方法

private void breakBarrier() {
    generation.broken = true;
    count = parties;
    trip.signalAll();
}

nextGeneration方法

private void nextGeneration() {
    trip.signalAll();
    count = parties;
    generation = new Generation(); // 生成了下一代party实例
}

\