AQS框架的定制化能力

251 阅读8分钟

前言

AQS(AbstractQueuedSynchronizer)是java并发编程中的一个基础类,许多juc下面的类都基于AQS进行扩展,实现并发同步的功能,本文首先将简单介绍下AQS类,接着介绍几个使用AQS框架的定制化能力进行并发控制的工具类,最后进行总结和扩展。

AQS简介

在并发控制中,会出现很多不满足临界条件(如抢锁失败)的线程,AQS提供了一种组织机制来组织,调配这些不满足临界条件的线程,用一个双链表将它们组织起来,从而进行后续的唤醒等操作。

结构

static final class Node {}

private volatile int state;

protected final int getState() {
    return state;
}
protected final void setState(int newState) {
    state = newState;
}

protected final boolean compareAndSetState(int expect, int update) {
    // See below for intrinsics setup to support this
    return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
}

Node就是双向链表的节点,state其实就是下面几个类操作、判断的变量,也是AQS类,以及使用AQS进行扩展的类的核心,这里compareAndSetState是一个原子操作,使用了Unsafe类来保证原子性。

protected boolean tryAcquire(int arg) {
        throw new UnsupportedOperationException();
    }
protected boolean tryRelease(int arg) {
    throw new UnsupportedOperationException();
}
protected int tryAcquireShared(int arg) {
    throw new UnsupportedOperationException();
}
protected boolean tryReleaseShared(int arg) {
    throw new UnsupportedOperationException();
}

这4个方法可以由子类重写来实现逻辑,从而基于AQS实现自己的同步器,下面介绍几个基于AQS实现的同步器。

基于AQS实现的同步器

主要介绍 ReentrantLock,CountDownLatch,Semaphore 这三个类,它们都有一个共性,就是它们内部都定义了一个继承自AQS的Sync类来进行操作

ReentrantLock

ReentrantLock是一个常用的锁,也是基于AQS扩展的一个同步器。

初始化和结构

public ReentrantLock() {
        sync = new NonfairSync();
    }
public ReentrantLock(boolean fair) {
    sync = fair ? new FairSync() : new NonfairSync();
}

private final Sync sync;

abstract static class Sync extends AbstractQueuedSynchronizer {
    ...
}

可以看到,其实有公平和非公平两种实现

static final class NonfairSync extends Sync {
        ...
    }

    /**
     * Sync object for fair locks
     */
    static final class FairSync extends Sync {
        ...
    }

Lock

public void lock() {
        sync.lock();
    }
// // AQS中
public final void acquire(int arg) {
    if (!tryAcquire(arg) &&
        acquireQueued(addWaiter(Node.EXCLUSIVE), arg))
        selfInterrupt();
}
// 非公平锁的实现
final void lock() {
    if (compareAndSetState(0, 1))
        setExclusiveOwnerThread(Thread.currentThread());
    else
        acquire(1);
}
protected final boolean tryAcquire(int acquires) {
    return nonfairTryAcquire(acquires);
}
final boolean nonfairTryAcquire(int acquires) {
    final Thread current = Thread.currentThread();
    int c = getState();
    if (c == 0) {
        if (compareAndSetState(0, acquires)) {
            setExclusiveOwnerThread(current);
            return true;
        }
    }
    else if (current == getExclusiveOwnerThread()) {
        int nextc = c + acquires;
        if (nextc < 0) // overflow
            throw new Error("Maximum lock count exceeded");
        setState(nextc);
        return true;
    }
    return false;
}
// 公平锁的实现
final void lock() {
    acquire(1);
}
// 具体实现
protected final boolean tryAcquire(int acquires) {
    final Thread current = Thread.currentThread();
    int c = getState();
    if (c == 0) {
        if (!hasQueuedPredecessors() &&
            compareAndSetState(0, acquires)) {
            setExclusiveOwnerThread(current);
            return true;
        }
    }
    else if (current == getExclusiveOwnerThread()) {
        int nextc = c + acquires;
        if (nextc < 0)
            throw new Error("Maximum lock count exceeded");
        setState(nextc);
        return true;
    }
    return false;
}
// 设置锁的拥有线程
protected final void setExclusiveOwnerThread(Thread thread) {
    exclusiveOwnerThread = thread;
}
非公平锁

总之就是先判断state是否为0,是的话,进行强锁,否则,由锁的持有者是否是本线程来增加锁的重入数

公平锁

流程和非公平锁基本是类似的,只是在state为0时,抢锁前会先判断阻塞队列里有没有节点在自己前面,体现了公平

总之如果没有抢到锁,最终会进入AQS的 acquireQueued(addWaiter(Node.EXCLUSIVE), arg)) 进行阻塞排队

首先调了addWaiter(Node.EXCLUSIVE) 将当前线程封装成一个 Node 节点,并添加到等待队列的尾部。

接下来除非这个线程是第一个节点并且调用 tryAcquire() 成功,否则,线程将调用 park() 方法让自己挂起。

tryLock

public boolean tryLock() {
    return sync.nonfairTryAcquire(1);
}
final boolean nonfairTryAcquire(int acquires) {
    ...
}

tryLock直接调的是nonfairTryAcquire,可见tryLock其实不管是否是公平锁,并且tryLock也没有后续的AQS操作了,直接返回true or false

unlock

public void unlock() {
    sync.release(1);
}
// AQS中
public final boolean release(int arg) {
    if (tryRelease(arg)) {
        Node h = head;
        if (h != null && h.waitStatus != 0)
            unparkSuccessor(h);
        return true;
    }
    return false;
    }
// 具体实现
protected final boolean tryRelease(int releases) {
    int c = getState() - releases;
    if (Thread.currentThread() != getExclusiveOwnerThread())
        throw new IllegalMonitorStateException();
    boolean free = false;
    if (c == 0) {
        free = true;
        setExclusiveOwnerThread(null);
    }
    setState(c);
    return free;
}

如果unlock成功了,会调用unparkSuccessor(h)方法来唤醒等待队列中的下一个节点所代表的线程,这里唤醒调用的是unpark方法,之前调用acquireQueued被挂起的第一个线程就可以被唤醒了

所以说,lock之后一定要unlock,不然那些挂起的线程永远不会被唤醒

可以看出,ReentrantLock通过重写AQS类的 tryAcquire 和 tryRelease 来定义了挂起和唤醒线程的条件。

CountDownLatch

CountDownLatch是一个常用的多线程同步器,可以用于需要等多个线程结束后主线程再进行某些操作的场景。

使用示例

CountDownLatch countDownLatch = new CountDownLatch(instanceIds.size());
instanceIds.forEach(instanceId -> {
    executorService.execute(() -> {
        try {
                xxx
            }finally {
                countDownLatch.countDown()
            }    });
});
countDownLatch.await();

如上代码,每个instanceId都初始化一个线程进行某些操作,等所有线程都完成了之后,主线程才可以向下运行。

结构

private static final class Sync extends AbstractQueuedSynchronizer {
    private static final long serialVersionUID = 4982264981922014374L;

    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }

    protected int tryAcquireShared(int acquires) {
        ...
    }

    protected boolean tryReleaseShared(int releases) {
        ...
    }
}

private final Sync sync;

可以看到也是定义了一个继承自AQS的Sync来进行操作的

初始化

public CountDownLatch(int count) {
    if (count < 0) throw new IllegalArgumentException("count < 0");
    this.sync = new Sync(count);
}
Sync(int count) {
    setState(count);
}

其实就是设置了AQS的state变量

countDown

public void countDown() {
    sync.releaseShared(1);
}
// AQS中
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}
protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
        int c = getState();
        if (c == 0)
            return false;
        int nextc = c-1;
        if (compareAndSetState(c, nextc))
            return nextc == 0;
    }
}

其实就是把state减1,这里如果state更新为0了,会调用AQS的doReleaseShared方法唤醒挂起的线程。

可以看到,也是通过重写了 AQS类的tryReleaseShared方法来定义自己的唤醒线程逻辑和条件,state为0时,才唤起挂起的线程。

await

public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}
// AQS中
public final void acquireSharedInterruptibly(int arg)
        throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        doAcquireSharedInterruptibly(arg);
}
protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
}

说白了,主线程await就是等这个state被减为0,也就是所有线程都完成工作,否则就调用doAcquireSharedInterruptibly方法把自己挂起

这里也是通过重写AQS类的 tryAcquireShared 来定义自己的挂起条件,即只要 state 不为 0,均挂起。

Semaphore

信号量也是常用的同步器,可以控制进入临界区的线程的数量,也是基于AQS的能力进行扩展的。

使用示例

public class SimpleSemaphoreExample {
    public static void main(String[] args) throws InterruptedException {
        // 初始化一个 Semaphore,允许最多2个线程同时执行
        Semaphore semaphore = new Semaphore(2);
        
        for (int i = 1; i <= 5; i++) {
            new Worker(semaphore, "Worker " + i).start();
        }
    }
    static class Worker extends Thread {
        private Semaphore semaphore;
        private String name;

        public Worker(Semaphore semaphore, String name) {
            this.semaphore = semaphore;
            this.name = name;
        }
        @Override
        public void run() {
            try {
                // 获取许可
                semaphore.acquire();
                System.out.println(name + " 开始执行任务.");
                
                // 模拟执行任务
                TimeUnit.SECONDS.sleep(2);
                
                System.out.println(name + " 任务完成.");
            } catch (InterruptedException e) {
                System.out.println(name + " 被中断.");
            } finally {
                // 释放许可
                semaphore.release();
            }
        }
    }
}

这是一个简单的例子,表明其作用,从运行结果可以看出Semaphore的作用就是控制线程同时运行的数量

结构和初始化

private final Sync sync;

    /**
     * Synchronization implementation for semaphore.  Uses AQS state
     * to represent permits. Subclassed into fair and nonfair
     * versions.
     */
    abstract static class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 1192457210091910933L;

        Sync(int permits) {
            setState(permits);
        }

        final int nonfairTryAcquireShared(int acquires) {
            ...
        }

        protected final boolean tryReleaseShared(int releases) {
            ...
        }

        final void reducePermits(int reductions) {
            ...
        }

        final int drainPermits() {
            ...
        }
    }

    /**
     * NonFair version
     */
    static final class NonfairSync extends Sync {
        private static final long serialVersionUID = -2694183684443567898L;

        NonfairSync(int permits) {
            super(permits);
        }

        protected int tryAcquireShared(int acquires) {
            return nonfairTryAcquireShared(acquires);
        }
    }

    /**
     * Fair version
     */
    static final class FairSync extends Sync {
        private static final long serialVersionUID = 2014338818796000944L;

        FairSync(int permits) {
            super(permits);
        }

        protected int tryAcquireShared(int acquires) {
...
        }
    }

同样有一个继承AQS类的Sync

初始化其实还是setState

Acquire

public void acquire() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}
// AQS中
public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
    if (Thread.interrupted())
        throw new InterruptedException();
    if (tryAcquireShared(arg) < 0)
        // 阻塞了
        doAcquireSharedInterruptibly(arg);
}
// 具体实现
final int nonfairTryAcquireShared(int acquires) {
    for (;;) {
        int available = getState();
        int remaining = available - acquires;
        if (remaining < 0 ||
            compareAndSetState(available, remaining))
            return remaining;
    }
}

这里acquire通过的条件其实是剩余的remaining大于等于0,感觉和限流有点像

通过重写AQS类的tryAcquireShared,定义了线程挂起的条件为 state - acquires(需要的许可数)小于 0

Release

public void release() {
    sync.releaseShared(1);
}
// AQS中
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) {
        doReleaseShared();
        return true;
    }
    return false;
}
// 具体实现
protected final boolean tryReleaseShared(int releases) {
    for (;;) {
        int current = getState();
        int next = current + releases;
        if (next < current) // overflow
            throw new Error("Maximum permit count exceeded");
        if (compareAndSetState(current, next))
            return true;
    }
}

具体实现其实就是尝试加state的值,成功了之后唤醒其它节点

通过重写AQS类的 tryReleaseShared 定义了唤醒其它线程的条件为 只要 state 成功增加了即可。

总结

可以看出,以上三个类,归根结底,都是靠操作和判断 state 这个变量,通过重写AQS类的方法,来设置挂起和唤醒线程的条件,从而达到对线程进行阻塞和唤醒的效果。

线程不挂起的条件
ReentrantLock把state从0设置为1,或者增大state
Semaphorestate-acquires>=0
CountDownLatchstate=0

不同的继承类,通过重写 tryAcquire , tryRelease , tryAcquireShared 和 tryReleaseShared 方法,实现自己定义的同步器,理解了这个原理,我们自己也可以编写符合自己需求的某种场景的同步器,比如编写一个不可重入的独占锁。

下面的代码通过重写AQS的 tryAcquire 和 tryRelease 定义了只有当 state 为 0,且原子性的将 state 设置为 1才算抢到锁,这是一个不可重入的独占锁。

private static class Sync extends AbstractQueuedSynchronizer {
    protected boolean isHeldExclusively() {
        return getState() == 1;
    }
    public boolean tryAcquire(int acquires) {
        if (compareAndSetState(0, 1)) {
            setExclusiveOwnerThread(Thread.currentThread());
            return true;
        }
        return false;
    }
    protected boolean tryRelease(int releases) {
        if (getState() == 0) {
            throw new IllegalMonitorStateException();
        }
        setExclusiveOwnerThread(null);
        setState(0);
        return true;
    }
}