ReentrantReadWriteLock源码分析

91 阅读3分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第9天,点击查看活动详情

类的继承关系

ReentrantReadWriteLock实现了ReadWriteLock接口,同时还实现了Serializable接口;ReadWriteLock接口中定义了读写锁的规范。

public class ReentrantReadWriteLock
        implements ReadWriteLock, java.io.Serializable 

类的内部组成

该类内部共有5个内部类,分别为Sync、NonfairSync、FairSync、ReadLock和WriteLock,其中前三个与ReentrantLock中的同名内部类类似。

image.png

后两个类则实现了Lock接口

image.png

下面开始逐步分析这5个类

Sync

/**
 * Synchronizer shared by the read lock and the write lock, built on AbstractQueuedSynchronizer.
 *
 * <p>The single AQS {@code int} state packs both lock modes: the high 16 bits hold the
 * shared (read) hold count and the low 16 bits hold the exclusive (write) hold count
 * (see {@link #sharedCount} and {@link #exclusiveCount}). Fairness is delegated to the
 * abstract {@link #readerShouldBlock} / {@link #writerShouldBlock} hooks.
 */
abstract static class Sync extends AbstractQueuedSynchronizer{
    private static final long serialVersionUID = 6317671515068378041L;

    
    // State layout: high 16 bits = read (shared) hold count, low 16 bits = write (exclusive) hold count.
    static final int SHARED_SHIFT   = 16;
    // Amount added to / subtracted from the state for a single read hold (1 in the high half).
    static final int SHARED_UNIT    = (1 << SHARED_SHIFT);
    // Maximum hold count (65535) for either mode; checked for reads and for reentrant writes.
    static final int MAX_COUNT      = (1 << SHARED_SHIFT) - 1;
    // Mask that extracts the write hold count from the low 16 bits of the state.
    static final int EXCLUSIVE_MASK = (1 << SHARED_SHIFT) - 1;

    // Total read holds encoded in c. Reentrant holds are counted, so this is NOT the number of reader threads.
    static int sharedCount(int c)    { return c >>> SHARED_SHIFT; }
    // Write hold count encoded in c (reentrant holds of the single writer), NOT a thread count.
    static int exclusiveCount(int c) { return c & EXCLUSIVE_MASK; }

    /**
     * Per-thread read-hold counter.
     * {@code count} is the number of read holds of the thread whose id is {@code tid};
     * {@code tid} is captured from the thread that constructs the counter.
     */
    static final class HoldCounter {
        // Read holds of the owning thread.
        int count = 0;
        // Id of the creating thread; compared with the current thread's id before a cached counter is reused.
        final long tid = getThreadId(Thread.currentThread());
    }

    /**
     * ThreadLocal that lazily gives each thread its own zeroed HoldCounter on first get().
     */
    static final class ThreadLocalHoldCounter
        extends ThreadLocal<HoldCounter> {
        // Overrides ThreadLocal.initialValue so each thread starts with a fresh counter.
        public HoldCounter initialValue() {
            return new HoldCounter();
        }
    }

  
    // Per-thread read hold counts for every reader except firstReader (tracked separately below).
    // A thread's entry is removed when its count drops to zero.
    private transient ThreadLocalHoldCounter readHolds;

    // HoldCounter of a recent reader; matched by tid first to avoid a ThreadLocal lookup.
    private transient HoldCounter cachedHoldCounter;

    // Thread that acquired the read lock while the shared count was 0. Its holds are kept in
    // firstReaderHoldCount instead of readHolds.
    private transient Thread firstReader = null;
    // Number of read holds of firstReader.
    private transient int firstReaderHoldCount;

    Sync() {
        // Per-thread read hold counters.
        readHolds = new ThreadLocalHoldCounter();
        // Writes the current state back unchanged. Per the trailing note this is done so that
        // readHolds is safely published (presumably relying on the volatile write in setState).
        setState(getState()); // ensures visibility of readHolds
    }

    

    // Fairness hook for readers. When true, tryAcquireShared skips its fast path, and
    // fullTryAcquireShared fails the attempt unless the caller already holds a read hold or the
    // write lock. Presumably implemented by FairSync / NonfairSync (not shown here).
    abstract boolean readerShouldBlock();

    
    // Fairness hook for writers. When true, a non-reentrant write acquisition fails without
    // attempting the CAS. Presumably implemented by FairSync / NonfairSync (not shown here).
    abstract boolean writerShouldBlock();

  
    /**
     * Releases {@code releases} write holds.
     *
     * @return true if the write hold count reached zero (the write lock is fully released)
     * @throws IllegalMonitorStateException if the caller is not the exclusive owner
     */
    protected final boolean tryRelease(int releases) {
        // Only the current exclusive owner may release the write lock.
        if (!isHeldExclusively())
            throw new IllegalMonitorStateException();
        // Write holds occupy the low 16 bits, so releases is subtracted directly.
        int nextc = getState() - releases;
        // Fully released once no write holds remain (read holds may still be present).
        boolean free = exclusiveCount(nextc) == 0;
        if (free)
            // Clear the owner before publishing the new state.
            setExclusiveOwnerThread(null);
        // Plain set (no CAS): the caller was just verified to be the exclusive owner.
        setState(nextc);
        return free;
    }

    /**
     * Attempts to acquire {@code acquires} write holds.
     *
     * <p>Fails if any read hold exists or another thread owns the write lock. Succeeds
     * reentrantly for the current owner. Otherwise, on a completely free lock, tries a CAS
     * unless {@link #writerShouldBlock} says the writer must wait.
     *
     * @return true if the write lock was acquired
     * @throws Error if the write hold count would exceed MAX_COUNT
     */
    protected final boolean tryAcquire(int acquires) {
        Thread current = Thread.currentThread();
        // Packed read/write state.
        int c = getState();
        // Current write hold count.
        int w = exclusiveCount(c);
        if (c != 0) {
            // c != 0 with w == 0 means read holds exist, so the write lock cannot be taken
            // (not even by a thread that holds a read lock itself). Otherwise, fail if another
            // thread owns the write lock.
            if (w == 0 || current != getExclusiveOwnerThread())
                return false;
            // Reentrant acquire by the owner: guard against overflowing the 16-bit write count.
            if (w + exclusiveCount(acquires) > MAX_COUNT)
                throw new Error("Maximum lock count exceeded");
            // Plain set (no CAS): the caller already owns the write lock.
            setState(c + acquires);
            return true;
        }
        // Lock is completely free: honour the fairness policy, then race other threads via CAS.
        if (writerShouldBlock() ||
            !compareAndSetState(c, c + acquires))
            return false;
        // Record the new exclusive owner.
        setExclusiveOwnerThread(current);
        return true;
    }

    /**
     * Releases one read hold of the current thread.
     *
     * @param unused ignored; each call releases exactly one SHARED_UNIT
     * @return true if the whole state (read and write holds) became 0
     * @throws IllegalMonitorStateException if the current thread holds no read hold
     */
    protected final boolean tryReleaseShared(int unused) {
        Thread current = Thread.currentThread();
        if (firstReader == current) {
            // The first reader's count lives in firstReaderHoldCount, not in readHolds.
            if (firstReaderHoldCount == 1)
                // Last hold: this thread stops being firstReader.
                firstReader = null;
            else
                firstReaderHoldCount--;
        } else {
            // Try the cached counter first, then fall back to this thread's ThreadLocal entry.
            HoldCounter rh = cachedHoldCounter;
            if (rh == null || rh.tid != getThreadId(current))
                rh = readHolds.get();
            int count = rh.count;
            if (count <= 1) {
                // The count is about to reach zero: drop the ThreadLocal entry.
                readHolds.remove();
                // Nothing to release: unlock without a matching lock.
                if (count <= 0)
                    throw unmatchedUnlockException();
            }
            --rh.count;
        }
        // CAS loop: subtract one read unit from the shared state, retrying on contention.
        for (;;) {
            int c = getState();
            int nextc = c - SHARED_UNIT;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }

    // Builds (does not throw) the exception for unlocking a read lock the thread does not hold.
    private IllegalMonitorStateException unmatchedUnlockException() {
        return new IllegalMonitorStateException(
            "attempt to unlock read lock, not locked by current thread");
    }
    
    /**
     * Attempts to acquire one read hold for the current thread.
     *
     * @param unused ignored; each acquisition adds exactly one SHARED_UNIT
     * @return 1 on success; -1 if another thread holds the write lock or the full retry loop fails
     */
    protected final int tryAcquireShared(int unused) {
       
        Thread current = Thread.currentThread();
        int c = getState();
        // Fail if the write lock is held by another thread. The write-lock owner itself may
        // still take the read lock.
        if (exclusiveCount(c) != 0 &&
            getExclusiveOwnerThread() != current)
            return -1;
        int r = sharedCount(c);
        // Fast path: fairness allows it, the read count has room, and the CAS wins.
        if (!readerShouldBlock() &&
            r < MAX_COUNT &&
            compareAndSetState(c, c + SHARED_UNIT)) {
            if (r == 0) {
                // Read count was 0: this thread becomes firstReader.
                firstReader = current;
                firstReaderHoldCount = 1;
            } else if (firstReader == current) {
                // Reentrant acquire by the first reader.
                firstReaderHoldCount++;
            } else {
                // Other readers: reuse the cached counter if it belongs to this thread,
                // otherwise fetch this thread's counter and cache it.
                HoldCounter rh = cachedHoldCounter;
                if (rh == null || rh.tid != getThreadId(current))
                    cachedHoldCounter = rh = readHolds.get();
                else if (rh.count == 0)
                    // A cached counter at 0 was removed from readHolds on release; re-install it.
                    readHolds.set(rh);
                rh.count++;
            }
            return 1;
        }
        // Fast path failed (fairness, saturated count, or lost CAS): use the full retry loop.
        return fullTryAcquireShared(current);
    }

    
    /**
     * Full retry loop for read acquisition, used when tryAcquireShared's fast path fails.
     * Also lets reentrant reads proceed when readerShouldBlock() is true.
     *
     * @return 1 on success; -1 if another thread holds the write lock, or if readers should
     *         block and the current thread holds no read hold
     * @throws Error if the read hold count is already MAX_COUNT
     */
    final int fullTryAcquireShared(Thread current) {
        // This thread's HoldCounter, looked up lazily and reused across loop iterations.
        HoldCounter rh = null;
        for (;;) {
            int c = getState();
            if (exclusiveCount(c) != 0) {
                if (getExclusiveOwnerThread() != current)
                    return -1;
                // else we hold the exclusive lock; blocking here
                // would cause deadlock.
            } else if (readerShouldBlock()) {
                // Make sure we're not acquiring read lock reentrantly
                if (firstReader == current) {
                    // assert firstReaderHoldCount > 0;
                } else {
                    if (rh == null) {
                        rh = cachedHoldCounter;
                        if (rh == null || rh.tid != getThreadId(current)) {
                            rh = readHolds.get();
                            // Don't leave an empty counter behind in readHolds.
                            if (rh.count == 0)
                                readHolds.remove();
                        }
                    }
                    // Not a reentrant read and readers should block: fail.
                    if (rh.count == 0)
                        return -1;
                }
            }
            // Read count saturated (16 bits).
            if (sharedCount(c) == MAX_COUNT)
                throw new Error("Maximum lock count exceeded");
            // Try to add one read hold; on CAS failure, loop and re-read the state.
            if (compareAndSetState(c, c + SHARED_UNIT)) {
                if (sharedCount(c) == 0) {
                    // Read count was 0: this thread becomes firstReader.
                    firstReader = current;
                    firstReaderHoldCount = 1;
                } else if (firstReader == current) {
                    firstReaderHoldCount++;
                } else {
                    if (rh == null)
                        rh = cachedHoldCounter;
                    if (rh == null || rh.tid != getThreadId(current))
                        rh = readHolds.get();
                    else if (rh.count == 0)
                        // Re-install a cached counter that was removed when it hit zero.
                        readHolds.set(rh);
                    rh.count++;
                    cachedHoldCounter = rh; // cache for release
                }
                return 1;
            }
        }
    }
    // NOTE(review): the article elides the remaining Sync members here; the next line is an
    // ellipsis placeholder from the excerpt, not valid Java.
    。。。。。
}

ReadLock和WriteLock中调用的方法均由Sync实现,不再赘述。

构造函数

ReentrantReadWriteLock():无参构造函数,默认创建非公平锁。

/**
 * Creates a read/write lock with the default, non-fair policy.
 * Equivalent to {@code new ReentrantReadWriteLock(false)}.
 */
public ReentrantReadWriteLock() {
    this(/* fair = */ false);
}

ReentrantReadWriteLock(boolean fair):有参构造函数,根据传入的参数决定创建公平锁(true)还是非公平锁(false)。

/**
 * Creates a read/write lock with the given fairness policy.
 *
 * @param fair {@code true} to back the lock with a {@code FairSync},
 *             {@code false} to back it with a {@code NonfairSync}
 */
public ReentrantReadWriteLock(boolean fair) {
    if (fair) {
        sync = new FairSync();
    } else {
        sync = new NonfairSync();
    }
    // Both lock views are built from this same instance.
    readerLock = new ReadLock(this);
    writerLock = new WriteLock(this);
}

ReentrantReadWriteLock示例

import java.util.concurrent.locks.ReentrantReadWriteLock;

/**
 * Demo reader: acquires the shared read lock, holds it for a while, then releases it.
 *
 * <p>Several readers can hold the read lock at the same time; a writer cannot acquire the
 * write lock while any read hold remains.
 */
class ReadThread extends Thread {
    /** Hold time used by the two-argument constructor, in milliseconds. */
    private static final long DEFAULT_HOLD_MILLIS = 5000;

    private final ReentrantReadWriteLock rrwLock;
    private final long holdMillis;

    /**
     * Creates a reader that holds the read lock for {@value #DEFAULT_HOLD_MILLIS} ms.
     *
     * @param name    thread name
     * @param rrwLock lock shared with the other demo threads
     */
    public ReadThread(String name, ReentrantReadWriteLock rrwLock) {
        this(name, rrwLock, DEFAULT_HOLD_MILLIS);
    }

    /**
     * Creates a reader that holds the read lock for {@code holdMillis} ms.
     *
     * @param name       thread name
     * @param rrwLock    lock shared with the other demo threads
     * @param holdMillis how long to hold the read lock, in milliseconds
     * @throws IllegalArgumentException if {@code holdMillis} is negative
     */
    public ReadThread(String name, ReentrantReadWriteLock rrwLock, long holdMillis) {
        super(name);
        if (holdMillis < 0) {
            throw new IllegalArgumentException("holdMillis must be >= 0: " + holdMillis);
        }
        this.rrwLock = rrwLock;
        this.holdMillis = holdMillis;
    }

    @Override
    public void run() {
        System.out.println(Thread.currentThread().getName() + " trying to lock");
        // Acquire BEFORE entering try: if lock() throws (e.g. "Maximum lock count exceeded"),
        // the finally block must not release a hold this call never obtained.
        rrwLock.readLock().lock();
        try {
            System.out.println(Thread.currentThread().getName() + " lock successfully");
            Thread.sleep(holdMillis);
        } catch (InterruptedException e) {
            System.out.println(Thread.currentThread().getName() + " interrupted while holding the read lock");
            // Restore the interrupt status (sleep cleared it) so the thread's owner can still see it.
            Thread.currentThread().interrupt();
        } finally {
            rrwLock.readLock().unlock();
            System.out.println(Thread.currentThread().getName() + " unlock successfully");
        }
    }
}

/**
 * Demo writer: acquires the exclusive write lock, reports success, and releases it immediately.
 *
 * <p>The write lock is granted only when no read holds exist and no other thread holds the
 * write lock.
 */
class WriteThread extends Thread {
    private final ReentrantReadWriteLock rrwLock;

    /**
     * @param name    thread name
     * @param rrwLock lock shared with the other demo threads
     */
    public WriteThread(String name, ReentrantReadWriteLock rrwLock) {
        super(name);
        this.rrwLock = rrwLock;
    }

    @Override
    public void run() {
        System.out.println(Thread.currentThread().getName() + " trying to lock");
        // Acquire BEFORE entering try: if lock() throws, the finally block must not unlock a
        // hold this call never obtained. Doing so would corrupt the hold count, or throw
        // IllegalMonitorStateException and hide the original failure.
        rrwLock.writeLock().lock();
        try {
            System.out.println(Thread.currentThread().getName() + " lock successfully");
        } finally {
            rrwLock.writeLock().unlock();
            System.out.println(Thread.currentThread().getName() + " unlock successfully");
        }
    }
}

/**
 * Entry point of the demo: two readers and one writer contend for one shared lock.
 */
public class ReentrantReadWriteLockDemo {
    /**
     * Creates a default (non-fair) read/write lock and starts rt1, rt2 and wt1 in that order.
     * Returns right after starting them; it does not wait for the threads to finish.
     * wt1 can only obtain the write lock once no read holds remain.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        ReentrantReadWriteLock sharedLock = new ReentrantReadWriteLock();
        Thread[] workers = {
            new ReadThread("rt1", sharedLock),
            new ReadThread("rt2", sharedLock),
            new WriteThread("wt1", sharedLock),
        };
        for (Thread worker : workers) {
            worker.start();
        }
    }
}