AQS内部原理和CountLatch原理分析

126 阅读5分钟

AQS内部原理分析

核心

  • 状态
  • 队列
  • 期望协作工具类实现锁的获取和释放

state状态

private volatile int state;

而 state 的含义根据具体实现类的作用不同而表示不同的含义

比如说在信号量里面,state 表示的是剩余许可证的数量。如果我们最开始把 state 设置为 10,这就代表许可证初始一共有 10 个,然后当某一个线程取走一个许可证之后,这个 state 就会变为 9,所以信号量的 state 相当于是一个内部计数器。

再比如,在 CountDownLatch 工具类里面,state 表示的是需要倒数的数量。一开始我们假设把它设置为 5,当每次调用 CountDown 方法时,state 就会减 1,一直减到 0 的时候就代表这个门闩被放开。

在 ReentrantLock 中它表示的是锁的占有情况,最开始是 0,表示没有任何线程占有锁;如果 state 变成 1,则就代表这个锁已经被某一个线

程所持有了。

那为什么还会变成 2、3、4 呢?为什么会往上加呢?因为 ReentrantLock 是可重入的,同一个线程可以再次拥有这把锁就叫重入。如果这个锁被同一个线程多次获取,那么 state 就会逐渐的往上加,state的值表示重入的次数。在释放的时候也是逐步递减,比如一开始是 4,释放一次就变成了 3,再释放一次变成了 2,这样进行的减操作,即便是减到 2 或者 1 了,都不代表这个锁是没有任何线程持有,只有当它减到 0 的时候,此时恢复到最开始的状态了,则代表现在没有任何线程持有这个锁了。所以,state 等于 0 表示锁不被任何线程所占有,代表这个锁当前是处于释放状态的,其他线程此时就可以来尝试获取了。

这就是 state 在不同类中不同含义的一个具体表现。我们举了三个例子,如果未来有新的工具要利用到AQS,它一定也需要利用 state,为这个类表示它所需要的业务逻辑和状态。

volatile 本身并不足以保证线程安全

我们举两个和 state 相关的方法,分别是 compareAndSetState 及 setState,它们的实现已经由 AQS去完成了,也就是说,我们直接调用这两个方法就可以对 state 进行线程安全的修改。下面就来看一下这两个方法的源码是怎么实现的。

compareAndSetState

 protected final boolean compareAndSetState(int expect, int update) {
     //利用Unsage里面的CAS,利用CPU指令的原子性保证这个操作的原子性
     //跟原子类保证线程安全的原理一致
     return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
 }

setState

protected final void setState(int newState) {
    //volatile对基本类型变量直接赋值时候可以保证线程安全
    state = newState;
}

FIFO队列

这个队列是双端链表,看看Doug Lea大佬给的图示:

image.png head:当前持有锁的线程,之后的都被阻塞

获取/释放方法

ReentrantLock 中的 lock 方法就是其中一个 “获取方法”,执行时,如果发现 state 不等于 0 且当前线程不是持有锁的线程,那么就代表这个锁已经被其他线程所持有了。这个时候,当然就获取不到锁,于是就让该线程进入阻塞状态。

Semaphore 中的 acquire 方法就是其中一个 “获取方法”,作用是获取许可证,此时能不能获取到这个许可证也取决于 state 的值。如果 state 值是正数,那么代表还有剩余的许可证,数量足够的话,就可以成功获取;但如果 state 是 0,则代表已经没有更多的空余许可证了,此时这个线程就获取不到许可证,会进入阻塞状态,所以这里同样也是和 state 的值相关的

CountDownLatch 获取方法就是 await 方法(包含重载方法),作用是 “等待,直到倒数结束”。执行 await 的时候会判断 state 的值,如果 state 不等于 0,线程就陷入阻塞状态,直到其他线程执行倒数方法把 state 减为 0,此时就代表现在这个门闩放开了,所以之前阻塞的线程就会被唤醒。

在 Semaphore 信号量里面,释放就是 release 方法(包含重载方法),release() 方法的作用是去释放一个许可证,会让 state 加 1;而在 CountDownLatch 里面,释放就是 countDown 方法,作用是倒数一个数,让 state 减 1

利用AQS写线程协作类

  1. 在内部写一个Sync类,Sync继承AQS
  2. 在Sync里根据是否独占重写对应方法
  3. 实现获取/释放的相关方法
protected boolean tryAcquire(int arg) {
    throw new UnsupportedOperationException();
}
protected int tryAcquireShared(int arg) {
    throw new UnsupportedOperationException();
}
protected boolean tryRelease(int arg) {
    throw new UnsupportedOperationException();
}
protected boolean tryReleaseShared(int arg) {
    throw new UnsupportedOperationException();
}

必须要重写这几个方法不然抛出异常

CountDownLatch

public class CountDownLatch {
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;
        
        //构造方法调用的就是这个,设置状态
        Sync(int count) {
            setState(count);
        }
​
        int getCount() {
            return getState();
        }
​
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }
​
        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }
    
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
    
    
    protected final void setState(int newState) {
        state = newState;
    }
    
    protected final int getState() {
        return state;
    }
}

getCount()

public long getCount() {
    return sync.getCount();
}

CountDown

public void countDown() {
    sync.releaseShared(1);
}

调用AQS releaseShared

public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}
protected boolean tryReleaseShared(int arg) {
    throw new UnsupportedOperationException();
}
private void doReleaseShared() {
    /*
     * Ensure that a release propagates, even if there are other
     * in-progress acquires/releases.  This proceeds in the usual
     * way of trying to unparkSuccessor of head if it needs
     * signal. But if it does not, status is set to PROPAGATE to
     * ensure that upon release, propagation continues.
     * Additionally, we must loop in case a new node is added
     * while we are doing this. Also, unlike other uses of
     * unparkSuccessor, we need to know if CAS to reset status
     * fails, if so rechecking.
     */
    for (;;) {
        Node h = head;
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            // loop to recheck cases
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}

await

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                int r = tryAcquireShared(arg);
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}
protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

总结

线程调用CDL的await(),尝试获取共享锁,一开始通常获取不到,于是线程被阻塞,“共享锁”可获取到的条件是"锁计数器"为0,而 “锁计数器” 的初始值为 count,当每次调用 CountDownLatch 对象的countDown 方法时,也可以把 “锁计数器” -1。通过这种方式,调用 count 次 countDown 方法之后,“锁计数器” 就为 0 了,于是之前等待的线程就会继续运行了,并且此时如果再有线程想调用 await 方法时也会被立刻放行,不会再去做任何阻塞操作了