JUC:JAVA7/8中的ConcurrentHashMap

586 阅读27分钟

前言

为什么要使用ConcurrentHashMap

  • 线程不安全的HashMap

在java7中,多线程put操作会导致循环链表。

虽然在java8中,修复了这个问题,但是还是会造成数据覆盖和扩容时的数据丢失。

  • 效率低下的HashTable

HashTable容器使用synchronized来保证线程安全,效率低下。

JAVA7 ConcurrentHashMap

ConcurrentHashMap是由Segment数据结构和HashEntry数组结构组成。Segment是一种可重入锁(ReentrantLock),在ConcurrentHashMap里扮演锁的角色;HashEntry则用于存储键值对数据。一个ConcurrentHashMap里包含一个Segment数组。Segment的结构和HashMap类似,是一种数组和链表结构。一个Segment里包含一个HashEntry数组,每个HashEntry是一个链表结构的元素,每个Segment守护着一个HashEntry数组里的元素,当对HashEntry数组的数据进行修改时,必须首先获得与它对应的Segment锁。

简单理解就是,ConcurrentHashMap 是一个 Segment 数组,Segment 通过继承 ReentrantLock 来进行加锁,所以每次需要加锁的操作锁住的是一个 segment,这样只要保证每个 Segment 是线程安全的,也就实现了全局的线程安全。

属性

    static final int DEFAULT_INITIAL_CAPACITY = 16;
    static final float DEFAULT_LOAD_FACTOR = 0.75f;
    static final int DEFAULT_CONCURRENCY_LEVEL = 16;
    static final int MAXIMUM_CAPACITY = 1 << 30;
    static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
    static final int MAX_SEGMENTS = 1 << 16;
    static final int RETRIES_BEFORE_LOCK = 2
    final int segmentMask;
    final int segmentShift;
    final Segment<K,V>[] segments;

构造函数

我们先来看看无参构造函数。

    public ConcurrentHashMap() {
        this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }

接着看另外一个构造函数。

    public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
        // 负载因子小于等于0或初始化容量小于0或同步等级小于等于0抛异常
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        // 同步等级最大为MAX_SEGMENTS
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        int sshift = 0;
        int ssize = 1;
        // 保证ssize大于等于同步等级且是2的n次方
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;
        }
        this.segmentShift = 32 - sshift;
        this.segmentMask = ssize - 1;
        // 最大初始化容量为MAXIMUM_CAPACITY
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        // 这里计算每个segment需要初始化多少容量
        int c = initialCapacity / ssize;
        if (c * ssize < initialCapacity)
            ++c;
        // 每个segment容量最小为MIN_SEGMENT_TABLE_CAPACITY
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)
            // 且segement容量为2的n次方
            cap <<= 1;
        // 初始化Segement0
        Segment<K,V> s0 =
            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        // 初始化Segment数组
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
        // 往数组写入 segment[0]
        UNSAFE.putOrderedObject(ss, SBASE, s0); 
        this.segments = ss;
    }

我们来看一下初始化做了哪些操作。

  • 初始化了Segment,默认数组长度为 16,不可以扩容。
  • Segment[i] 的默认大小为 2,负载因子是0.75,得出初始阈值为1.5,也就是以后插入第一个元素不会触发扩容,插入第二个会进行第一次扩容。
  • 当前 segmentShift 的值为 32 - 4 = 28,segmentMask 为 16 - 1 = 15,姑且把它们简单翻译为移位数和掩码,这两个值马上就会用到。

put

    public V put(K key, V value) {
        Segment<K,V> s;
        if (value == null)
            throw new NullPointerException();
        // 计算key的hash值
        int hash = hash(key);
        // 根据Key的哈希值找到segment数组中的位置
        // hash值右移segmentShift位,和segmentMask做与运算
        // 也就是key的hash值的前n位来确定该key存在于哪个segment
        // 这个n位就是segment数组容量(size-1)的位数
        int j = (hash >>> segmentShift) & segmentMask;
        // 上面说过,初始化只初始化了segment[0]
        // 判断定位到的segment是否为null
        if ((s = (Segment<K,V>)UNSAFE.getObject          
             (segments, (j << SSHIFT) + SBASE)) == null) 
            // 需要初始化当前的segment
            s = ensureSegment(j);
        // 真正执行put
        return s.put(key, hash, value, false);
    }

ensureSegment 初始化segment

ConcurrentHashMap 初始化的时候会初始化第一个槽 segment[0],对于其他槽来说,在插入第一个值的时候进行初始化。

这里需要考虑并发,因为很可能会有多个线程同时进来初始化同一个槽 segment[k],不过只要有一个成功了就可以。

    private Segment<K,V> ensureSegment(int k) {
        final Segment<K,V>[] ss = this.segments;
        long u = (k << SSHIFT) + SBASE; 
        Segment<K,V> seg;
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
            // 这里看到为什么之前要初始化 segment[0] 了
            // 使用当前 segment[0] 处的数组长度和负载因子来初始化 segment[k]
            // 为什么要用“当前”,因为 segment[0] 可能早就扩容过了
            Segment<K,V> proto = ss[0]; 
            int cap = proto.table.length; 
            float lf = proto.loadFactor;
            int threshold = (int)(cap * lf);
            // 初始化 segment[k] 内部的数组
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                == null) { 
                // 再次检查一遍该槽是否被其他线程初始化了。
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                // 使用 while 循环,内部用 CAS,当前线程成功设值或其他线程成功设值后,退出
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                       == null) {
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }

总的来说,ensureSegment(int k) 比较简单,对于并发操作使用 CAS 进行控制。

这里为什么要搞一个 while 循环,CAS 失败不就代表有其他线程成功了吗,为什么要再进行判断?

如果当前线程 CAS 失败,这里的 while 循环是为了将 seg 赋值返回。

Segment#put

        final V put(K key, int hash, V value, boolean onlyIfAbsent) {
            // 在往该segment里写入前,需要先获取该segment的独占锁
            HashEntry<K,V> node = tryLock() ? null :
                scanAndLockForPut(key, hash, value);
            V oldValue;
            try {
                // 获取该segment的内部数组
                HashEntry<K,V>[] tab = table;
                int index = (tab.length - 1) & hash;
                // 获取数组下标为index链表的表头
                HashEntry<K,V> first = entryAt(tab, index);
                for (HashEntry<K,V> e = first;;) {
                    // 如果表头不为null,则表示哈希碰撞了
                    if (e != null) {
                        K k;
                        // 判断是否同一个key
                        // 首先判断内存地址是否同一,如果是则不需要后续判断了
                        if ((k = e.key) == key ||
                            // 如果内存地址不同,则判断hash值是否同一个,如果不是也不需要后续判断了
                            // 最后才equals
                            (e.hash == hash && key.equals(k))) {
                            // 记录旧值
                            oldValue = e.value;
                            if (!onlyIfAbsent) {
                                // 覆盖旧值
                                e.value = value;
                                ++modCount;
                            }
                            break;
                        }
                        e = e.next;
                    }
                    else {
                        // node如果不为null,直接设置为表头
                        if (node != null)
                            node.setNext(first);
                        else
                            // node如果为null,初始化并设置表头
                            node = new HashEntry<K,V>(hash, key, value, first);
                        int c = count + 1;
                        // 如果size大于阈值,扩容
                        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                            rehash(node);
                        else
                            // 没有达到阈值,将 node 放到数组 tab 的 index 位置,
                            // 其实就是将新的节点设置成原链表的表头
                            setEntryAt(tab, index, node);
                        ++modCount;
                        count = c;
                        oldValue = null;
                        break;
                    }
                }
            } finally {
                unlock();
            }
            return oldValue;
        }

由于有独占锁的保护,整个流程还是比较简单的。

Segment#scanAndLockForPut 获取写入锁

        private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
            HashEntry<K,V> first = entryForHash(this, hash);
            HashEntry<K,V> e = first;
            HashEntry<K,V> node = null;
            int retries = -1; 
            // tryLock失败则会进入这个循环
            while (!tryLock()) {
                HashEntry<K,V> f;
                // 如果是第一次重试
                if (retries < 0) {
                    // 如果该位置链表为null
                    if (e == null) {
                        // 如果没有初始化过node则初始化一个node
                        if (node == null) 
                            node = new HashEntry<K,V>(hash, key, value, null);
                        retries = 0;
                    }
                    // 如果哈希碰撞了,重试次数为0
                    else if (key.equals(e.key))
                        retries = 0;
                    else
                        // 顺着链表往下走
                        e = e.next;
                }
                // 如果大于最大重试次数了
                else if (++retries > MAX_SCAN_RETRIES) {
                    // 直接调用lock阻塞至获取锁
                    lock();
                    break;
                }
                else if ((retries & 1) == 0 &&
                        // 如果这个过程中链表的表头换了,代表有新的元素插入到该链表了
                         (f = entryForHash(this, hash)) != first) {
                    // 策略就是重新走一次scanAndLockForPut
                    e = first = f; 
                    retries = -1;
                }
            }
            return node;
        }
  • 循环获取锁
  • 如果获取锁失败,会循着表头往下走,走到底或者发现产生了hash碰撞才不走,走到底则初始化node,产生了hash碰撞了不初始化。
  • 上一步走完才会加重试次数并判断是否超过最大重试次数,如果超过最大重试次数直接调用lock阻塞获取锁。
  • 然后每循环两次判断一下链表头是否被换掉了,如果换掉了就代表链表有变化,重新走一次scanAndLockForPut方法。

Segment#rehash 扩容

    // 方法参数中的node是本次新put的数据
    private void rehash(HashEntry<K,V> node) {
        HashEntry<K,V>[] oldTable = table;
        int oldCapacity = oldTable.length;
        // 两倍扩容
        int newCapacity = oldCapacity << 1;
        // 新的阈值
        threshold = (int)(newCapacity * loadFactor);
        HashEntry<K,V>[] newTable =
            (HashEntry<K,V>[]) new HashEntry[newCapacity];
        // 新的掩码
        int sizeMask = newCapacity - 1;
        // 遍历数组内每一个链表
        for (int i = 0; i < oldCapacity ; i++) {
            HashEntry<K,V> e = oldTable[i];
            if (e != null) {
                HashEntry<K,V> next = e.next;
                int idx = e.hash & sizeMask;
                // 如果下一个为当前链表的next结点为null,直接为当前节点设置新的位置
                if (next == null)   
                    newTable[idx] = e;
                else { 
                    // 如果链表上不止一个值
                    HashEntry<K,V> lastRun = e;
                    int lastIdx = idx;
                    // 找到一个lastRun,也就是说lastRun之后的所有链表节点是放在同一个链表里的
                    for (HashEntry<K,V> last = next;
                         last != null;
                         last = last.next) {
                        int k = last.hash & sizeMask;
                        if (k != lastIdx) {
                            lastIdx = k;
                            lastRun = last;
                        }
                    }
                    // 处理lastRun
                    newTable[lastIdx] = lastRun;
                    // 处理lastRun之前的节点,这些节点要不和lastRun在一个链表中,要不在另外一个链表中
                    for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                        V v = p.value;
                        int h = p.hash;
                        int k = h & sizeMask;
                        HashEntry<K,V> n = newTable[k];
                        newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                    }
                }
            }
        }
        // 设置当前新加入的节点
        int nodeIndex = node.hash & sizeMask; // add the new node
        node.setNext(newTable[nodeIndex]);
        newTable[nodeIndex] = node;
        table = newTable;
    }

这里的扩容比较要复杂,代码难懂一点。上面有两个挨着的 for 循环,第一个 for 有什么用呢?

仔细一看发现,如果没有第一个 for 循环,也是可以工作的,但是,这个 for 循环下来,如果 lastRun 的后面还有比较多的节点,那么这次就是值得的。因为我们只需要克隆 lastRun 前面的节点,后面的一串节点跟着 lastRun 走就是了,不需要做任何操作。

remove 删除

        final V remove(Object key, int hash, Object value) {
            if (!tryLock())
                scanAndLock(key, hash);
            V oldValue = null;
            try {
                HashEntry<K,V>[] tab = table;
                int index = (tab.length - 1) & hash;
                HashEntry<K,V> e = entryAt(tab, index);
                HashEntry<K,V> pred = null;
                while (e != null) {
                    K k;
                    HashEntry<K,V> next = e.next;
                    if ((k = e.key) == key ||
                        (e.hash == hash && key.equals(k))) {
                        V v = e.value;
                        if (value == null || value == v || value.equals(v)) {
                            if (pred == null)
                                setEntryAt(tab, index, next);
                            else
                                pred.setNext(next);
                            ++modCount;
                            --count;
                            oldValue = v;
                        }
                        break;
                    }
                    pred = e;
                    e = next;
                }
            } finally {
                unlock();
            }
            return oldValue;
        }

加锁然后再删除。

get 获取

    public V get(Object key) {
        Segment<K,V> s; 
        HashEntry<K,V>[] tab;
        // 计算hash值
        int h = hash(key);
        // 算出segment中的位置
        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
        // 判断segment以及table是否为空
        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
            (tab = s.table) != null) {
            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                     (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
                 e != null; e = e.next) {
                K k;
                // 遍历链表找到具体的key
                if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                    return e.value;
            }
        }
        return null;
    }

get操作相对来说比较简单,就不赘述了。

size 数量

    public int size() {
        final Segment<K,V>[] segments = this.segments;
        int size;
        // 是否超过了int的最大值
        boolean overflow; 
        long sum;         
        long last = 0L;  
        int retries = -1; 
        try {
            for (;;) {
                // 如果重试次数等于2,直接循环对所有segment加锁
                if (retries++ == RETRIES_BEFORE_LOCK) {
                    for (int j = 0; j < segments.length; ++j)
                        ensureSegment(j).lock(); 
                }
                sum = 0L;
                size = 0;
                overflow = false;
                // 优先考虑不加锁策略
                for (int j = 0; j < segments.length; ++j) {
                    Segment<K,V> seg = segmentAt(segments, j);
                    if (seg != null) {
                        sum += seg.modCount;
                        int c = seg.count;
                        if (c < 0 || (size += c) < 0)
                            overflow = true;
                    }
                }
                // 重试第一次时,last为0,如果concurrenthashmap中有元素
                // 那么一定会有第二次循环
                if (sum == last)
                    break;
                last = sum;
            }
        } finally {
            if (retries > RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    segmentAt(segments, j).unlock();
            }
        }
        return overflow ? Integer.MAX_VALUE : size;
    }

还记得我们put操作时的modcount吗,如果当前ConcurrentHashMap有值,则不加锁计算每个segment的size总量以及modcount总量,计算两次如果modcount总量是一致的,就证明没有并发存在,直接返回size。如果不一致,就需要对每一个segment加锁然后计算size总量。

JAVA7 ConcurrentHashMap并发总结

现在我们已经说完了 put 过程和 get 过程,我们可以看到 get 过程中是没有加锁的,那自然我们就需要去考虑并发问题。

添加节点的操作 put 和删除节点的操作 remove 都是要加 segment 上的独占锁的,所以它们之间自然不会有问题,我们需要考虑的问题就是 get 的时候在同一个 segment 中发生了 put 或 remove 操作。

  • put 操作的线程安全性
    • 初始化槽,这个我们之前就说过了,使用了 CAS 来初始化 Segment 中的数组。
    • 添加节点到链表的操作是插入到表头的,所以,如果这个时候 get 操作在链表遍历的过程已经到了中间,是不会影响的。当然,另一个并发问题就是 get 操作在 put 之后,需要保证刚刚插入表头的节点被读取,这个依赖于 setEntryAt 方法中使用的 UNSAFE.putOrderedObject。
    • 扩容。扩容是新创建了数组,然后进行迁移数据,最后面将 newTable 设置给属性 table。所以,如果 get 操作此时也在进行,那么也没关系,如果 get 先行,那么就是在旧的 table 上做查询操作;而 put 先行,那么 put 操作的可见性保证就是 table 使用了 volatile 关键字。
  • remove操作的线程安全性
    • get 操作需要遍历链表,但是 remove 操作会"破坏"链表。
    • 如果 remove 破坏的节点 get 操作已经过去了,那么这里不存在任何问题。
    • 如果 remove 先破坏了一个节点,分两种情况考虑。 1、如果此节点是头结点,那么需要将头结点的 next 设置为数组该位置的元素,table 虽然使用了 volatile 修饰,但是 volatile 并不能提供数组内部操作的可见性保证,所以源码中使用了 UNSAFE 来操作数组,请看方法 setEntryAt。2、如果要删除的节点不是头结点,它会将要删除节点的后继节点接到前驱节点中,这里的并发保证就是 next 属性是 volatile 的。

JAVA8 ConcurrentHashMap

ConcurrentHashMap在java8中做了不小的改动,取消了segment,引入了红黑树,用synchronized替代了reentrantlock。

我们先用一个示意图来描述下其结构:

结构上和 Java8 的 HashMap 基本上一样,不过它要保证线程安全性,所以在源码上确实要复杂一些。

构造函数

    public ConcurrentHashMap() {
    }

无参构造函数啥事都没做。

    public ConcurrentHashMap(int initialCapacity) {
        if (initialCapacity < 0)
            throw new IllegalArgumentException();
        int cap = ((initialCapacity >= (MAXIMUM_CAPACITY >>> 1)) ?
                   MAXIMUM_CAPACITY :
                   // (1.5 * initialCapacity + 1)向上取最近的2的n次方。
                   tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1));
        this.sizeCtl = cap;
    }

设置了sizeCtl为 (1.5 * initialCapacity + 1) 向上取最近的2的n次方,先不说sizeCtl在这里是干嘛用的,各位同学继续往下看。

put插入

    public V put(K key, V value) {
        return putVal(key, value, false);
    }
    final V putVal(K key, V value, boolean onlyIfAbsent) {
        if (key == null || value == null) throw new NullPointerException();
        // 计算hash值
        int hash = spread(key.hashCode());
        int binCount = 0;
        for (Node<K,V>[] tab = table;;) {
            Node<K,V> f; int n, i, fh;
            // 如果table为null,也就是第一次put,那么初始化table     
            if (tab == null || (n = tab.length) == 0)
                tab = initTable();
            // 如果table已经初始化过了,则通过hash值对应的下标得到第一个节点f
            // 如果f为null,则证明该下标第一次参与put
            else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
                // 采用CAS的方式将新的值放入即可
                if (casTabAt(tab, i, null,
                             new Node<K,V>(hash, key, value, null)))
                    break;                   
            }
            // 得到第一个节点的hash值fh,如果该hash值等于MOVED(-1)则帮助扩容
            // hash值为什么会等于-1,继续往下看
            else if ((fh = f.hash) == MOVED)
                tab = helpTransfer(tab, f);
            else {
                // 走到这里就表明发生了hash碰撞
                V oldVal = null;
                // 对第一个节点f加监视器锁
                synchronized (f) {
                    // 在获取一次下标i对应的节点,如果不是f则表明在获取监视器锁之前发生了扩容
                    if (tabAt(tab, i) == f) {
                        // 如果f的hash值大于等于0,则表明f为链表
                        if (fh >= 0) {
                            binCount = 1;
                            // 从链表头往表后找
                            for (Node<K,V> e = f;; ++binCount) {
                                K ek;
                                // 如果key值相等,则替换
                                if (e.hash == hash &&
                                    ((ek = e.key) == key ||
                                     (ek != null && key.equals(ek)))) {
                                    oldVal = e.val;
                                    if (!onlyIfAbsent)
                                        e.val = value;
                                    break;
                                }
                                Node<K,V> pred = e;
                                // 如果链表内没有找到相等的则新建一个node加入到链表尾部
                                if ((e = e.next) == null) {
                                    pred.next = new Node<K,V>(hash, key,
                                                              value, null);
                                    break;
                                }
                            }
                        }
                        // 如果节点为红黑树
                        else if (f instanceof TreeBin) {
                            Node<K,V> p;
                            binCount = 2;
                            // 则采用红黑树的方式插入
                            if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                                           value)) != null) {
                                oldVal = p.val;
                                if (!onlyIfAbsent)
                                    p.val = value;
                            }
                        }
                    }
                }
                // 如果节点插入成功
                if (binCount != 0) {
                    // 判断是否需要扩容
                    if (binCount >= TREEIFY_THRESHOLD)
                        treeifyBin(tab, i);
                    // 结束循环
                    if (oldVal != null)
                        return oldVal;
                    break;
                }
            }
        }
        addCount(1L, binCount);
        return null;
    }
  • 首先计算hash值。
  • 如果table未空,则初始化table。
  • 如果hash对应的下标为空,则直接初始化一个Node并采用cas的方式设置进去。
  • 如果正在扩容,则调用helpTransfer帮助扩容。
  • 如果产生了hash碰撞,也就是hash对应的node不为空,则对该获取该node的监视器锁,如果该node对应的是链表,则以链表的方式插入,如果是红黑树则以红黑树的方式插入。
  • 最后判断是否插入成功,如果没有插入成功则继续循环,插入成功则判断是否需要扩容。
  • 最后addCount计数。

initTable 初始化table

    private final Node<K,V>[] initTable() {
        Node<K,V>[] tab; int sc;
        while ((tab = table) == null || tab.length == 0) {
            // 如果sizeCtl小于0,则表明正在扩容,让出cpu
            if ((sc = sizeCtl) < 0)
                Thread.yield(); 
            // cas方式将sizeCtl设置为-1
            else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
                try {
                    // 再判断一次是否已经初始化过了
                    if ((tab = table) == null || tab.length == 0) {
                        // 设置容量
                        int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                        // 初始化table数组
                        Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                        table = tab = nt;
                        // 设置阈值为n*(1-0.25)
                        sc = n - (n >>> 2);
                    }
                } finally {
                    // 最后设置阈值
                    sizeCtl = sc;
                }
                break;
            }
        }
        return tab;
    }
  • 循环判断是否需要初始化
  • 如果sizeCtl<0,表示有线程正在初始化。
  • 如果没有初始化,则使用CAS方式将sizeCtl设置为-1(这里是用volatile变量代替锁的思想)。
  • 重新判断一次table是否未初始化(双重校验锁思想)。
  • 初始化table
  • 最后设置阈值(0.75 * size)。

helpTransfer 帮助扩容

首先需要介绍一下,ForwardingNode 这个节点类型,

    static final class ForwardingNode<K,V> extends Node<K,V> {
        final Node<K,V>[] nextTable;
        ForwardingNode(Node<K,V>[] tab) {
            super(MOVED, null, null, null);
            this.nextTable = tab;
        }
    }

这个节点内部保存了一 nextTable 引用,它指向一张 hash 表。在扩容操作中,我们需要对每个桶中的结点进行分离和转移,如果某个桶结点中所有节点都已经迁移完成了(已经被转移到新表 nextTable 中了),那么会在原 table 表的该位置挂上一个 ForwardingNode 结点,说明此桶已经完成迁移。

ForwardingNode 继承自 Node 结点,并且它唯一的构造函数将构建一个键,值,next 都为 null 的结点,反正它就是个标识,无需那些属性。但是 hash 值却为 MOVED。

所以,我们在 putVal 方法中遍历整个 hash 表的桶结点,如果遇到 hash 值等于 MOVED,说明已经有线程正在扩容 rehash 操作,整体上还未完成,不过我们要插入的桶的位置已经完成了所有节点的迁移。

由于检测到当前哈希表正在扩容,于是让当前线程去协助扩容。

    final Node<K,V>[] helpTransfer(Node<K,V>[] tab, Node<K,V> f) {
        Node<K,V>[] nextTab; int sc;
        if (tab != null && (f instanceof ForwardingNode) &&
            (nextTab = ((ForwardingNode<K,V>)f).nextTable) != null) {
            //返回一个 16 位长度的扩容校验标识
            int rs = resizeStamp(tab.length);
            while (nextTab == nextTable && table == tab &&
                   (sc = sizeCtl) < 0) {
                //sizeCtl 如果处于扩容状态的话
                //前 16 位是数据校验标识,后 16 位是当前正在扩容的线程总数
                 //这里判断校验标识是否相等,如果校验符不等或者扩容操作已经完成了,直接退出循环,不用协助它们扩容了
                if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                    sc == rs + MAX_RESIZERS || transferIndex <= 0)
                    break;
                //否则调用 transfer 帮助它们进行扩容
                //sc + 1 标识增加了一个线程进行扩容
                if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1)) {
                    transfer(tab, nextTab);
                    break;
                }
            }
            return nextTab;
        }
        return table;
    }

transfer 转移

transfer 方法比较,我们分几个部分来细说。

第一部分

    private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
        int n = tab.length, stride;
        //计算单个线程允许处理的最少table桶首节点个数,不能小于 16
        if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
            stride = MIN_TRANSFER_STRIDE; 
        //刚开始扩容,初始化 nextTab 
        if (nextTab == null) {
            try {
                @SuppressWarnings("unchecked")
                Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
                nextTab = nt;
            } catch (Throwable ex) {
                sizeCtl = Integer.MAX_VALUE;
                return;
            }
            nextTable = nextTab;
            //transferIndex 指向最后一个桶,方便从后向前遍历 
            transferIndex = n;
        }
        int nextn = nextTab.length;
        //定义 ForwardingNode 用于标记迁移完成的桶
        ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);

主要完成的是对单个线程能处理的最少桶结点个数的计算和一些属性的初始化操作。

第二部分

boolean advance = true;
boolean finishing = false;
//i 指向当前桶,bound 指向当前线程需要处理的桶结点的区间下限
for (int i = 0, bound = 0;;) {
       Node<K,V> f; int fh;
       //这个 while 循环的目的就是通过 --i 遍历当前线程所分配到的桶结点
       //一个桶一个桶的处理
       while (advance) {
           int nextIndex, nextBound;
           if (--i >= bound || finishing)
               advance = false;
           //transferIndex <= 0 说明已经没有需要迁移的桶了
           else if ((nextIndex = transferIndex) <= 0) {
               i = -1;
               advance = false;
           }
           //更新 transferIndex
           //为当前线程分配任务,处理的桶结点区间为(nextBound,nextIndex)
           else if (U.compareAndSwapInt(this, TRANSFERINDEX, nextIndex,nextBound = (nextIndex > stride ? nextIndex - stride : 0))) {
               bound = nextBound;
               i = nextIndex - 1;
               advance = false;
           }
       }
       //当前线程所有任务完成
       if (i < 0 || i >= n || i + n >= nextn) {
           int sc;
           if (finishing) {
               nextTable = null;
               table = nextTab;
               sizeCtl = (n << 1) - (n >>> 1);
               return;
           }
           if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
               if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                   return;
               finishing = advance = true;
               i = n; 
           }
       }
       //待迁移桶为空,那么在此位置 CAS 添加 ForwardingNode 结点标识该桶已经被处理过了
       else if ((f = tabAt(tab, i)) == null)
           advance = casTabAt(tab, i, null, fwd);
       //如果扫描到 ForwardingNode,说明此桶已经被处理过了,跳过即可
       else if ((fh = f.hash) == MOVED)
           advance = true; 

每个新参加进来扩容的线程必然先进 while 循环的最后一个判断条件中去领取自己需要迁移的桶的区间。然后 i 指向区间的最后一个位置,表示迁移操作从后往前的做。接下来的几个判断就是实际的迁移结点操作了。等我们大致介绍完成第三部分的源码再回来对各个判断条件下的迁移过程进行详细的叙述。

第三部分

else {
    synchronized (f) {
        if (tabAt(tab, i) == f) {
            Node<K,V> ln, hn;
            //链表的迁移操作
            if (fh >= 0) {
                int runBit = fh & n;
                Node<K,V> lastRun = f;
                //整个 for 循环为了找到整个桶中最后连续的 fh & n 不变的结点
                for (Node<K,V> p = f.next; p != null; p = p.next) {
                    int b = p.hash & n;
                    if (b != runBit) {
                        runBit = b;
                        lastRun = p;
                    }
                }
                if (runBit == 0) {
                    ln = lastRun;
                    hn = null;
                }
                else {
                    hn = lastRun;
                    ln = null;
                }
                //如果fh&n不变的链表的runbit都是0,则nextTab[i]内元素ln前逆序,ln及其之后顺序
                //否则,nextTab[i+n]内元素全部相对原table逆序
                //这是通过一个节点一个节点的往nextTab添加
                for (Node<K,V> p = f; p != lastRun; p = p.next) {
                    int ph = p.hash; K pk = p.key; V pv = p.val;
                    if ((ph & n) == 0)
                        ln = new Node<K,V>(ph, pk, pv, ln);
                    else
                        hn = new Node<K,V>(ph, pk, pv, hn);
                }
                //把两条链表整体迁移到nextTab中
                setTabAt(nextTab, i, ln);
                setTabAt(nextTab, i + n, hn);
                //将原桶标识位已经处理
                setTabAt(tab, i, fwd);
                advance = true;
            }
            //红黑树的复制算法,不再赘述
            else if (f instanceof TreeBin) {
                TreeBin<K,V> t = (TreeBin<K,V>)f;
                TreeNode<K,V> lo = null, loTail = null;
                TreeNode<K,V> hi = null, hiTail = null;
                int lc = 0, hc = 0;
                for (Node<K,V> e = t.first; e != null; e = e.next) {
                    int h = e.hash;
                    TreeNode<K,V> p = new TreeNode<K,V>(h, e.key, e.val, null, null);
                    if ((h & n) == 0) {
                        if ((p.prev = loTail) == null)
                            lo = p;
                        else
                            loTail.next = p;
                    loTail = p;
                    ++lc;
                    }
                    else {
                        if ((p.prev = hiTail) == null)
                            hi = p;
                        else
                            hiTail.next = p;
                    hiTail = p;
                    ++hc;
                    }
                }
                ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :(hc != 0) ? new TreeBin<K,V>(lo) : t;
                hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :(lc != 0) ? new TreeBin<K,V>(hi) : t;
                setTabAt(nextTab, i, ln);
                setTabAt(nextTab, i + n, hn);
                setTabAt(tab, i, fwd);
                advance = true;
           }

那么至此,有关迁移的几种情况已经介绍完成了,下面我们整体上把控一下整个扩容和迁移过程。

首先,每个线程进来会先领取自己的任务区间,然后开始 --i 来遍历自己的任务区间,对每个桶进行处理。如果遇到桶的头结点是空的,那么使用 ForwardingNode 标识该桶已经被处理完成了。如果遇到已经处理完成的桶,直接跳过进行下一个桶的处理。如果是正常的桶,对桶首节点加锁,正常的迁移即可,迁移结束后依然会将原表的该位置标识位已经处理。

当 i < 0,说明本线程处理速度够快的,整张表的最后一部分已经被它处理完了,现在需要看看是否还有其他线程在自己的区间段还在迁移中。这是退出的逻辑判断部分:

       if (i < 0 || i >= n || i + n >= nextn) {
           int sc;
           if (finishing) {
               nextTable = null;
               table = nextTab;
               sizeCtl = (n << 1) - (n >>> 1);
               return;
           }
           if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
               if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                   return;
               finishing = advance = true;
               i = n; 
           }
       }

finnish 是一个标志,如果为 true 则说明整张表的迁移操作已经全部完成了,我们只需要重置 table 的引用并将 nextTable 赋为空即可。否则,CAS 式的将 sizeCtl 减一,表示当前线程已经完成了任务,退出扩容操作。

如果退出成功,那么需要进一步判断是否还有其他线程仍然在执行任务。

if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
   return;

我们说过 resizeStamp(n) 返回的是对 n 的一个数据校验标识,占 16 位。而 RESIZE_STAMP_SHIFT 的值为 16,那么位运算后,整个表达式必然在右边空出 16 个零。也正如我们所说的,sizeCtl 的高 16 位为数据校验标识,低 16 为表示正在进行扩容的线程数量。

(resizeStamp(n) << RESIZE_STAMP_SHIFT) + 2 表示当前只有一个线程正在工作,相对应的,如果 (sc - 2) == resizeStamp(n) << RESIZE_STAMP_SHIFT,说明当前线程就是最后一个还在扩容的线程,那么会将 finishing 标识为 true,并在下一次循环中退出扩容方法。

treeifyBin 链表转为红黑树

    private final void treeifyBin(Node<K,V>[] tab, int index) {
        Node<K,V> b; int n, sc;
        if (tab != null) {
            // 如果数组长度小于MIN_TREEIFY_CAPACITY(64)的时候会首先扩容
            if ((n = tab.length) < MIN_TREEIFY_CAPACITY)
                // 扩容
                tryPresize(n << 1);
            // 获取头节点,且头节点的hash值要大于等于0(这里hash值为-1表示有其他线程正在扩容)
            else if ((b = tabAt(tab, index)) != null && b.hash >= 0) {
                // 获取头节点的监视器锁
                synchronized (b) {
                    // 当前头节点和之前头节点是否一致,如果一致则将链表转换为红黑树
                    if (tabAt(tab, index) == b) {
                        TreeNode<K,V> hd = null, tl = null;
                        // 将链表转换为红黑树
                        for (Node<K,V> e = b; e != null; e = e.next) {
                            TreeNode<K,V> p =
                                new TreeNode<K,V>(e.hash, e.key, e.val,
                                                  null, null);
                            if ((p.prev = tl) == null)
                                hd = p;
                            else
                                tl.next = p;
                            tl = p;
                        }
                        // 将红黑树放入数组中对应的下标
                        setTabAt(tab, index, new TreeBin<K,V>(hd));
                    }
                }
            }
        }
    }
  • 首先判断是否需要转换为红黑树,如果table的长度不足64,则优先考虑扩容,扩容为两倍。
  • 如果需要转换为红黑树,则获取头节点的监视器锁,循环转换为红黑树,填充到对应的位置即可。

tryPresize 扩容

    private final void tryPresize(int size) {
        // c 为 size 的 1.5 倍,再加 1,再往上取最近的 2 的 n 次方。
        int c = (size >= (MAXIMUM_CAPACITY >>> 1)) ? MAXIMUM_CAPACITY :
            tableSizeFor(size + (size >>> 1) + 1);
        int sc;
        // 大于0则表示不在扩容或初始化
        while ((sc = sizeCtl) >= 0) {
            Node<K,V>[] tab = table; int n;
            // 如果table未初始化就初始化,和initTable流程完成相同。
            if (tab == null || (n = tab.length) == 0) {
                n = (sc > c) ? sc : c;
                if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
                    try {
                        if (table == tab) {
                            @SuppressWarnings("unchecked")
                            Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                            table = nt;
                            sc = n - (n >>> 2);
                        }
                    } finally {
                        sizeCtl = sc;
                    }
                }
            }
            // 如果c小于阈值或者n大于等于数组最大值,则直接结束。
            else if (c <= sc || n >= MAXIMUM_CAPACITY)
                break;
            else if (tab == table) {
                int rs = resizeStamp(n);
                if (sc < 0) {
                    Node<K,V>[] nt;
                    if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                        sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                        transferIndex <= 0)
                        break;
                    //    用 CAS 将 sizeCtl 加 1,然后执行 transfer 方法
                    //    此时 nextTab 不为 null
                    if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                        transfer(tab, nt);
                }
                //  将 sizeCtl 设置为 (rs << RESIZE_STAMP_SHIFT) + 2)
                //  调用 transfer 方法,此时 nextTab 参数为 null
                else if (U.compareAndSwapInt(this, SIZECTL, sc,
                                             (rs << RESIZE_STAMP_SHIFT) + 2))
                    transfer(tab, null);
            }
        }
    }

这个方法的核心在于 sizeCtl 值的操作,首先将其设置为一个负数,然后执行 transfer(tab, null),再下一个循环将 sizeCtl 加 1,并执行 transfer(tab, nt),之后可能是继续 sizeCtl 加 1,并执行 transfer(tab, nt)。

addCount

    private final void addCount(long x, int check) {
        CounterCell[] as; long b, s;
        //如果更新失败才会进入的 if 的主体代码中
        if ((as = counterCells) != null ||
            !U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
            CounterCell a; long v; int m;
            boolean uncontended = true;
             //高并发下 CAS 失败会执行 fullAddCount 方法
            if (as == null || (m = as.length - 1) < 0 ||
                (a = as[ThreadLocalRandom.getProbe() & m]) == null ||
                !(uncontended =
                  U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
                fullAddCount(x, uncontended);
                return;
            }
            if (check <= 1)
                return;
            s = sumCount();
        }
        //判断是否需要扩容
        if (check >= 0) {
            Node<K,V>[] tab, nt; int n, sc;
            while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
                   (n = tab.length) < MAXIMUM_CAPACITY) {
                int rs = resizeStamp(n);
                if (sc < 0) {
                    if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                        sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                        transferIndex <= 0)
                        break;
                    if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                        transfer(tab, nt);
                }
                else if (U.compareAndSwapInt(this, SIZECTL, sc,
                                             (rs << RESIZE_STAMP_SHIFT) + 2))
                    transfer(tab, null);
                s = sumCount();
            }
        }
    }

这部分代码也是比较简单的,不再赘述。

remove 方法实现并发删除

在我们分析完 put 方法的源码之后,相信 remove 方法对你而言就比较轻松了,无非就是先定位再删除的复合。

限于篇幅,我们这里简单的描述下 remove 方法的并发删除过程。

首先遍历整张表的桶结点,如果表还未初始化或者无法根据参数的 hash 值定位到桶结点,那么将返回 null。

如果定位到的桶结点类型是 ForwardingNode 结点,调用 helpTransfer 协助扩容。

否则就老老实实的给桶加锁,删除一个节点。

最后会调用 addCount 方法 CAS 更新 baseCount 的值。

size

size 方法的作用是为我们返回哈希表中实际存在的键值对的总数。

    public int size() {
        long n = sumCount();
        return ((n < 0L) ? 0 :(n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE :(int)n);
    }
    CounterCell[] as = counterCells; CounterCell a;
    long sum = baseCount;
    if (as != null) {
        for (int i = 0; i < as.length; ++i) {
            if ((a = as[i]) != null)
                sum += a.value;
        }
    }
    return sum;

可能你会有所疑问,ConcurrentHashMap 中的 baseCount 属性不就是记录的所有键值对的总数吗?直接返回它不就行了吗?

之所以没有这么做,是因为我们的 addCount 方法用于 CAS 更新 baseCount,但很有可能在高并发的情况下,更新失败,那么这些节点虽然已经被添加到哈希表中了,但是数量却没有被统计。

还好,addCount 方法在更新 baseCount 失败的时候,会调用 fullAddCount 将这些失败的结点包装成一个 CounterCell 对象,保存在 CounterCell 数组中。那么整张表实际的 size 其实是 baseCount 加上 CounterCell 数组中元素的个数。

get

get 方法可以根据指定的键,返回对应的键值对,由于是读操作,所以不涉及到并发问题。源码也是比较简单的。

public V get(Object key) {
        Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
        int h = spread(key.hashCode());
        if ((tab = table) != null && (n = tab.length) > 0 &&
            (e = tabAt(tab, (n - 1) & h)) != null) {
            if ((eh = e.hash) == h) {
                if ((ek = e.key) == key || (ek != null && key.equals(ek)))
                    return e.val;
            }
            else if (eh < 0)
                return (p = e.find(h, key)) != null ? p.val : null;
            while ((e = e.next) != null) {
                if (e.hash == h &&
                    ((ek = e.key) == key || (ek != null && key.equals(ek))))
                    return e.val;
            }
        }
        return null;
    }

总结

三个字,给跪了!

参考文献

Java7/8 中的 HashMap 和 ConcurrentHashMap 全解析 作者非常强,觉得很多地方讲的很好就偷过来了。