Java 并发——深入 ReadWriteLock 原理

131 阅读7分钟

ReadWriteLock 背后维护着一对相互关联的锁,一个用于读,一个用于写。读锁可以被多个读线程并发获取, 只要没有写线程。而写锁不支持此种情况。读写锁,读-读能共存,读-写不能共存,写-写不能共存。

自实现

如果某个线程想要读取资源,只要没有线程正在对该资源进行写操作且没有线程请求对该资源的写操作即可。

public class ReadWriteLock {
	private int readers = 0;
	private int writers = 0;
	private int writeRequests = 0;

	public synchronized void lockRead() throws InterruptedException{
		while(writers > 0 || writeRequests > 0){
			wait();
		}
		readers++;
	}

	public synchronized void unlockRead(){
		readers--;
		notifyAll();
	}

	public synchronized void lockWrite() throws InterruptedException{
		writeRequests++;

		while(readers > 0 || writers > 0){
			wait();
		}
		writeRequests--;
		writers++;
	}

	public synchronized void unlockWrite() throws InterruptedException{
		writers--;
		notifyAll();
	}
}

比较简单的一个读写锁实现,其中有一个问题就是如果写操作很频繁,那么读线程可能会产生"饥饿现象"。 需要注意的是,在两个释放锁的方法(unlockRead,unlockWrite)中,都调用了notifyAll方法, 而不是notify。要解释这个原因,我们可以想象下面一种情形:

如果有线程在等待获取读锁,同时又有线程在等待获取写锁。如果这时其中一个等待读锁的线程被
notify() 方法唤醒,但因为此时仍有请求写锁的线程存在(writeRequests>0),所以被
唤醒的线程会再次进入阻塞状态。然而,等待写锁的线程一个也没被唤醒,就像什么也没发生过一样
(信号丢失)。如果用的是 notifyAll() 方法,所有的线程都会被唤醒,然后判断能否获得其
请求的锁。

用 notifyAll() 方法还有一个好处。如果有多个读线程在等待读锁且没有线程在等待写锁时, 调用 unlockWrite() 方法后,所有等待读锁的线程都能立马成功获取读锁,而不是一次只允许一个。

读写锁重入

上面代码不支持读写锁重入的,当一个已经持有写锁的线程再次请求写锁时,就会被阻塞。原因是已经有一个 写线程了——就是它自己。此外,考虑下面的例子:

  1. thread#1 获得了读锁
  2. thread#2 请求写锁,但因为 thread#1 持有了读锁,所以写锁请求被阻塞
  3. thread#1 再想请求一次读锁,但因为 thread#2 处于请求写锁的状态,所以想再次获取读锁也会被阻塞

所以实现读写锁可重入的思路就是:

  • 保存当前获取读锁的线程,并计数。当前读线程之前如果获取到了读锁,则直接重入,否则, 如果有写线程或有写线程在等待锁,那么获取读锁失败,需要等待
  • 保存当前获取写锁的线程,并计数。如果没有读线程,或没有写线程,那么获取写锁成功。或者同一线程再次 获取写锁。

ReentrantReadWriteLock

类图

ReentrantReadWriteLock

ReentrantReadWriteLock.Sync

数据结构:

// 读锁和写锁计数常量
// 锁状态逻辑上被划分成两部分无符号整数:
// 地位表示排他锁计数(exclusive(writer) lock)
// 高位表示共享锁计数(shared(reader) lock)
// 共享状态 state 为 int 型,32 位,所以 SHARED_SHIFT = 16,各占一半
static final int SHARED_SHIFT   = 16;
// 65536,二进制:10000000000000000
static final int SHARED_UNIT    = (1 << SHARED_SHIFT);
// 65535
static final int MAX_COUNT      = (1 << SHARED_SHIFT) - 1;
// 65535,二进制:1111111111111111
static final int EXCLUSIVE_MASK = (1 << SHARED_SHIFT) - 1;

// 第一个获取读锁的线程,即把 shared count 从 0 改变为 1 的线程
private transient Thread firstReader = null;
private transient int firstReaderHoldCount;

// 上一个线程获取读锁的计数
private transient HoldCounter cachedHoldCounter;

// 当前线程所持有的可重入读锁的计数
private transient ThreadLocalHoldCounter readHolds;

构建 ReentrantReadWriteLock.Sync 时会进行初始化操作

Sync() {
    readHolds = new ThreadLocalHoldCounter();
    setState(getState()); // 默认为 0
}

非公平加锁

ReadLock(读锁)

读锁,实现了 Lock 接口,内部依赖对 AQS 的实现 Sync 类。

// ReadLock 类
// 如果写锁没有被其他线程获取,那么立即返回,否则,阻塞
 public void lock() {
    sync.acquireShared(1);
}

// AQS 类
public final void acquireShared(int arg) {
    if (tryAcquireShared(arg) < 0)
        doAcquireShared(arg);
}

// ReentrantReadWriteLock.Sync 类
// unused = 1
protected final int tryAcquireShared(int unused) {
    /*
     * 1. 如果有其他线程获取了 写锁,失败
     * 2. 
     */
    Thread current = Thread.currentThread();
    int c = getState();
    
    // exclusiveCount(c) => c & 1111111111111111(Sync.EXCLUSIVE_MASK)
    
    // ① 判断是否有其他线程获取了 写锁。锁降级:如果获取写锁的线程可以获取读锁
    if (exclusiveCount(c) != 0 &&
        getExclusiveOwnerThread() != current)
        return -1;
        
    // sharedCount(c) => c >>> 16(Sync.SHARED_SHIFT) (高 16 位表示共享读)
    int r = sharedCount(c);
    
    // ② 核心
    if (!readerShouldBlock() &&
        r < MAX_COUNT &&
        compareAndSetState(c, c + SHARED_UNIT)) {
        if (r == 0) { 
            firstReader = current;
            firstReaderHoldCount = 1;
        } else if (firstReader == current) { 
            firstReaderHoldCount++;
        } else {
            HoldCounter rh = cachedHoldCounter;
            if (rh == null || rh.tid != getThreadId(current))
                cachedHoldCounter = rh = readHolds.get();
            else if (rh.count == 0)
                readHolds.set(rh);
            rh.count++;
        }
        return 1;
    }
    // ③
    return fullTryAcquireShared(current);
}

// NonfairSync 类
final boolean readerShouldBlock() {
    return apparentlyFirstQueuedIsExclusive();
}
// AQS 类
// 如果队列中等待的第一个节点(head 节点不算等待,因为它表示的线程拿到锁了)是 写线程,
// 那么 读线程 应该被 阻塞。目的防止 写线程 产生饥饿现象
final boolean apparentlyFirstQueuedIsExclusive() {
    // 1. 队列未被初始化
    // 1.1 读线程来获取读锁,返回 false
    // 2. 队列被初始化
    // 2.1 队列当前只有一个节点 head,返回 false
    // 2.2 队列不止一个节点 head
    // 2.2.1 第一个等待节点的 nextWaiter 不是 SHARED 模式,返回 true
    Node h, s;
    return (h = head) != null &&
        (s = h.next)  != null &&
        !s.isShared()         &&
        s.thread != null;
}

核心逻辑在上述 tryAcquireShared() 方法注释 ② 标示的 if() 判断逻辑,我们一步一步来解析。 if() 判断条件由三部分组成,并且是 && 关系,也就是只要有一部分不满足,就进入注释 ③ 标示的逻辑代码。

  • !readerShouldBlock()
    上述代码里解释了 readerShouldBlock() 方法逻辑。注意这里还对其结果取反了。
  • r < MAX_COUNT
    ReentrantReadWriteLock.Sync 类的数据结构有说 MAX_COUTN 字段。这里表示 读线程 不能超过 65535 个,否则会报错。
  • compareAndSetState(c, c + SHARED_UNIT)
    c 变量是当前 state 的值。如果此时第一个 读线程 进来,c = 0,c + SHARED_UNIT = 65536,为啥是这样?因为高 16 位表示 共享读,并且计算 读线程 的个数是 c >>> 16。 读线程获取锁,修改 state 值是state += 65536。

当满足上述条件时,表示线程已经拿到读锁了,if() 判断条件的内部逻辑分三步走:

  1. 还没有线程获取读锁,即

     (getState() >>> SHARED_SHIFT) = 0
    

    记录第一个获取读锁的线程,以及它的重入数。

  2. 如果第一步不满足,判断 firstReader 跟当前线程是否同一线程,是,firstReaderHoldCount 增加。

  3. 否则,执行下面逻辑

    HoldCounter rh = cachedHoldCounter; // 当前线程的缓存计数
    // rh 还未被初始化,即没有第二个线程进来获取读锁
    if (rh == null || rh.tid != getThreadId(current))
        // 初始化 cachedHoldCounter
        cachedHoldCounter = rh = readHolds.get();
    else if (rh.count == 0)
        readHolds.set(rh);
    rh.count++;
    

如果注释 ② 的 if() 判断没有满足,即线程不能获取读锁,或已经达到最高读线程数,或 CAS 失败,那么都会进入注释 ③ 的方法 fullTryAcquireShared()。

// ReentrantReadWriteLock.Sync 类
final int fullTryAcquireShared(Thread current) {
    HoldCounter rh = null;
    // 死循环
    for (;;) {
        int c = getState();
        
        // 是否有写线程获取多写锁。锁降级
        if (exclusiveCount(c) != 0) {
        
            if (getExclusiveOwnerThread() != current)
                // 有写线程获取到写锁,其他线程来获取读锁,直接返回 -1,进队列
                return -1;
                
        } else if (readerShouldBlock()) { // 没有写锁(写锁可能被释放)
 			  // head(读线程) -> node1(写线程,等待)
            if (firstReader == current) {
                // assert firstReaderHoldCount > 0;
            } else {
                if (rh == null) {
                    rh = cachedHoldCounter;
                    if (rh == null || rh.tid != getThreadId(current)) {
                        rh = readHolds.get();
                        if (rh.count == 0)
                            readHolds.remove();
                    }
                }
                
                if (rh.count == 0)
                    return -1;
            }
        }
        
        // ③ 计数设计
        if (sharedCount(c) == MAX_COUNT)
            throw new Error("Maximum lock count exceeded");
        if (compareAndSetState(c, c + SHARED_UNIT)) {
            if (sharedCount(c) == 0) {
                firstReader = current;
                firstReaderHoldCount = 1;
            } else if (firstReader == current) {
                firstReaderHoldCount++;
            } else {
                if (rh == null)
                    rh = cachedHoldCounter;
                if (rh == null || rh.tid != getThreadId(current))
                    rh = readHolds.get();
                else if (rh.count == 0)
                    readHolds.set(rh);
                rh.count++;
                cachedHoldCounter = rh; // cache for release
            }
            return 1;
        }
    }
}

// AQS 类
// 获取读锁
public final void acquireShared(int arg) {
    if (tryAcquireShared(arg) < 0)
        doAcquireShared(arg);
}

private void doAcquireShared(int arg) {
    // 获取读锁失败,入队列
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        boolean interrupted = false;
        // 循环
        for (;;) {
            // 前继节点
            final Node p = node.predecessor();
            if (p == head) { // 前继节点为 head
                int r = tryAcquireShared(arg);
                if (r >= 0) { // 获取到读锁
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    if (interrupted)
                        selfInterrupt();
                    failed = false;
                    return;
                }
            }
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                interrupted = true;
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

我们理下进入 fullTryAcquireShared() 方法的各种情况:

  1. 写锁降级,即 thread1 获取到写锁,且准备再获取读锁

    • thread1 获取到读锁之前,线程 B 准备获取写锁,此时 thread1 能正常获取读锁

      队列

      进入注释③ 的代码

    • thread1 获取到读锁之前,没有其他线程准备获取写锁

      Sync

  2. thread1 获取到写锁,thread2 来获取读锁

    Read

WriteLock(写锁)

整个过程同 ReentrantLock 类的获取锁逻辑差不多。

小结

  • 写锁可以降级为读锁,读锁不能升级为写锁
  • 同步队列中第一个等待的节点是写线程,则不能获取读锁,需要排队
  • 读线程获取到读锁,需要传递后继节点