JDK 源码解析 :ConcurrentHashMap 1.7

388 阅读9分钟

结构

图摘自网络

ConcurrentHashMap是一个Segment数组。而一个Segement是一个HashEntry数组。而一个HashEntry则存储具体的key-value对。同时还扮演着链表节点的角色,用next执行链表中下一个节点。

ConcurrentHashMap借助segment来提高并发性能,发生在不同segment上的操作不存在线程安全性问题。一个segment可以看做一个小型的hashmap。ConcurentHashMap可以看作是由多个segment组成的一个联邦。 put,get,remove,定位到具体的segment然后交由segment去操作。

与扩容有关的capacity, threshold, loadfactor自然都是针对每一个segment而言的了。扩容同样也只发生在一个segmet中。

Segment

static class Segment<K,V> extends ReentrantLock implements Serializable {
  transient volatile HashEntry<K,V>[] table;
  transient int count;
  transient int modCount;
  
}

Put

put操作定位到key对应的segemnt。(语句1),然后调用2返回key对应的segment(如果不存在该方法会初始化段) 最后将put委托给segment操作(语句3)。

    public V put(K key, V value) {
        Segment<K,V> s;
        if (value == null)
            throw new NullPointerException();
        int hash = hash(key.hashCode());
        int j = (hash >>> segmentShift) & segmentMask;
        if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
             (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment  // 1
            s = ensureSegment(j);  //2
        return s.put(key, hash, value, false);  //3
    }

ensureSegment

初始化段的时候没有加锁,所以这里存在竞争的情况。采用CAS来保证线程安全性。

 @SuppressWarnings("unchecked")
    private Segment<K,V> ensureSegment(int k) {
        final Segment<K,V>[] ss = this.segments;
        long u = (k << SSHIFT) + SBASE; // raw offset
        Segment<K,V> seg;
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
            Segment<K,V> proto = ss[0]; // use segment 0 as prototype
            int cap = proto.table.length;
            float lf = proto.loadFactor;
            int threshold = (int)(cap * lf);
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                == null) { // recheck
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                       == null) {
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }

Segment内部类中的Put方法

语句0首先调用scanAndForPut方法获取锁,这个方法会返回一个待插入的节点node。该方法的细节见后。

获取锁之后,就定位到table中的槽,然后遍历链表。

语句1处说明找到了key节点,则更新value。

语句2处如果node不为NULL,则直接利用该节点,设置该节点的next指针指向链表头节点(头插法插入新节点)

如果node为NULL,则此处新建一个节点。

语句3处segment中的节点数量加一,如果此时节点数目超过了阈值,则调用rehash方法扩容。

注意,到现在为止,我们仅仅完成了待插入节点的相关设置,我们还需要将table数组中对应的槽指向node(头插法插入新节点)。因此我们将node节点作为参数传到rehash方法交由他完成。

语句4处,如果不需要扩容,则调用setEntryAt(tab, index, node)来将node放到槽位中。

 final V put(K key, int hash, V value, boolean onlyIfAbsent) {
            HashEntry<K,V> node = tryLock() ? null : //0
                scanAndLockForPut(key, hash, value);
            V oldValue;
            try {
                HashEntry<K,V>[] tab = table;
                int index = (tab.length - 1) & hash;
                HashEntry<K,V> first = entryAt(tab, index);
                for (HashEntry<K,V> e = first;;) {
                    if (e != null) { 
                        K k;
                        if ((k = e.key) == key ||
                            (e.hash == hash && key.equals(k))) {
                            oldValue = e.value;
                            if (!onlyIfAbsent) {
                                e.value = value;     // 1
                                ++modCount;
                            }
                            break;
                        }
                        e = e.next;
                    }
                    else {
                        if (node != null)      //2
                            node.setNext(first);
                        else
                            node = new HashEntry<K,V>(hash, key, value, first);
                        int c = count + 1;
                        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                            rehash(node);   // 3
                        else         
                            setEntryAt(tab, index, node); //4
                        ++modCount;
                        count = c;
                        oldValue = null;
                        break;
                    }
                }
            } finally {
                unlock();
            }
            return oldValue;
        }
        

scanAndLockForPut

scanAndLockForPut这个方法用于插入时获取segment的锁。采用自旋的方式去获取锁,

当获取不到的时候,这个方法并没有闲着,而是沿着链表遍历找到put应该插入(或者更新)的位置。这个目的在于提前将链表节点加载进本地缓存中,这样等获取锁之后可以直接从缓存中读取,省去了cache miss 的开销。在这里,直接调用equals而不是先比较hash因为这里的遍历速度不重要。如果找到了插入的位置,用node提前分配好应该插入的节点。等到获取锁成功了返回node。如果找到了更新的位置(即链表中存在相同key的节点)那么返回的node 为NULL。(语句1)

用retries为负表示在寻找插入(更新)位置。当找位置的时候,循环会一直在语句1中执行。寻找完毕,retries为0。这时候会进入语句2,3。

语句2在每次自旋中增加retries次数判断是否达到上限,

语句3会在每次自旋次数为偶数时,检查hash对应的槽的链表的表头是否改变了,如果改变了,要重新遍历链表。 这是因为我们想尽可能的遍历到链表的新状态。否则如果其他持有锁的线程进行rehash等操作。导致链表状态大变。那么我们之前的遍历结果对于改善cache miss的效果便下降很多。

语句3 (retries & 1) == 0 还有一个特点是,对于单核处理器,retries的最大值为1。因此当第二次自旋时,不会去检查链表头是否被改变。因为我们这时候如果不去阻塞自己,那么便会占用持有锁的线程的CPU资源,降低了性能。


        private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
            HashEntry<K,V> first = entryForHash(this, hash);
            HashEntry<K,V> e = first;
            HashEntry<K,V> node = null;
            int retries = -1; // negative while locating node
            while (!tryLock()) {
                HashEntry<K,V> f; // to recheck first below
                if (retries < 0) {  // 1
                    if (e == null) {
                        if (node == null) // speculatively create node
                            node = new HashEntry<K,V>(hash, key, value, null);
                        retries = 0;
                    }
                    else if (key.equals(e.key))
                        retries = 0;
                    else
                        e = e.next;
                }
                else if (++retries > MAX_SCAN_RETRIES) {  //2
                    lock();
                    break;
                }
                else if ((retries & 1) == 0 &&  //3
                         (f = entryForHash(this, hash)) != first) {
                    e = first = f; // re-traverse if entry changed
                    retries = -1;
                }
            }
            return node;
        }
    

entryForHash方法用于返回指定segment中,hash对应的槽(bucket)。

    @SuppressWarnings("unchecked")
    static final <K,V> HashEntry<K,V> entryForHash(Segment<K,V> seg, int h) {
        HashEntry<K,V>[] tab;
        return (seg == null || (tab = seg.table) == null) ? null :
            (HashEntry<K,V>) UNSAFE.getObjectVolatile
            (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
    }

扩容rehash

扩容当且仅当put时发现节点数量超过阈值时,因此扩容时已经获取了segment对应的锁,无线程安全性问题。

扩容和hashmap扩容操作基本相同。扩容包含1. 重新生成table数组 2. 将旧table上节点迁移到新table上。

语句2寻找一个lastRun开头的链表,这个链表上的节点的新位置都相同,因此转移的时候只需要将lastRun开头的整个链表放到槽中。节约了复制的开销。

语句2.5处,将lastRun初始化为链表头的next节点,而不是链表头。(这里也可以初始化为链表头。)

语句3处从旧链表头部遍历到lastRun。复制其中的每一个节点到新的位置。

        /**
         * Doubles size of table and repacks entries, also adding the
         * given node to new table
         */
        @SuppressWarnings("unchecked")
        private void rehash(HashEntry<K,V> node) {
        
            HashEntry<K,V>[] oldTable = table;
            int oldCapacity = oldTable.length;
            int newCapacity = oldCapacity << 1; //1
            threshold = (int)(newCapacity * loadFactor);
            HashEntry<K,V>[] newTable =
                (HashEntry<K,V>[]) new HashEntry[newCapacity];
            int sizeMask = newCapacity - 1;
            for (int i = 0; i < oldCapacity ; i++) {
                HashEntry<K,V> e = oldTable[i];
                if (e != null) {
                    HashEntry<K,V> next = e.next;
                    int idx = e.hash & sizeMask;  //新的索引
                    if (next == null)   //  Single node on list
                        newTable[idx] = e;
                    else { // Reuse consecutive sequence at same slot  //2
                        HashEntry<K,V> lastRun = e;
                        int lastIdx = idx;
                        for (HashEntry<K,V> last = next;  //2.5
                             last != null;
                             last = last.next) {
                            int k = last.hash & sizeMask;
                            if (k != lastIdx) {
                                lastIdx = k;
                                lastRun = last;
                            }
                        }
                        newTable[lastIdx] = lastRun;
                        // Clone remaining nodes
                        for (HashEntry<K,V> p = e; p != lastRun; p = p.next) { //3
                            V v = p.value;
                            int h = p.hash;
                            int k = h & sizeMask;
                            HashEntry<K,V> n = newTable[k];
                            newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                        }
                    }
                }
            }
            int nodeIndex = node.hash & sizeMask; // add the new node
            node.setNext(newTable[nodeIndex]);
            newTable[nodeIndex] = node;
            table = newTable;
        }

get

get操作无需加锁,获取key对应的segment,获取segemnt对应的table,然后遍历table上槽对应的链表即可。

    public V get(Object key) {
        Segment<K,V> s; // manually integrate access methods to reduce overhead
        HashEntry<K,V>[] tab;
        int h = hash(key.hashCode());
        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
            (tab = s.table) != null) {
            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                     (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
                 e != null; e = e.next) {
                K k;
                if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                    return e.value;
            }
        }
        return null;
    }

remove

    public V remove(Object key) {
        int hash = hash(key.hashCode());
        Segment<K,V> s = segmentForHash(hash);
        return s == null ? null : s.remove(key, hash, null);
    }

Segemnt的remove

    final V remove(Object key, int hash, Object value) {
            if (!tryLock())
                scanAndLock(key, hash);
            V oldValue = null;
            try {
                HashEntry<K,V>[] tab = table;
                int index = (tab.length - 1) & hash;
                HashEntry<K,V> e = entryAt(tab, index);
                HashEntry<K,V> pred = null;
                while (e != null) {
                    K k;
                    HashEntry<K,V> next = e.next;
                    if ((k = e.key) == key ||
                        (e.hash == hash && key.equals(k))) {
                        V v = e.value;
                        if (value == null || value == v || value.equals(v)) {
                            if (pred == null)
                                setEntryAt(tab, index, next);
                            else
                                pred.setNext(next);
                            ++modCount;
                            --count;
                            oldValue = v;
                        }
                        break;
                    }
                    pred = e;
                    e = next;
                }
            } finally {
                unlock();
            }
            return oldValue;
        }

Size

可能返回一个过期值

遍历MAp中节点数量的方法有2种:一种是不加锁算2次,比较这2次的modCount是否变化来判断是否存在竞争。 另一种是当发现竞争时转而采用给每一个segment加锁来计算。size()方法,isEmpty()方法都采用了这种。

语句2 遍历segement数组,累加每一个segemnt的size与modCount。同时检查size是否溢出

语句3 检查modCount与之前的modCount是否一样。为了保证安全性,会至少遍历2轮,如果这一轮与上一轮modCount总和一样的话,就会退出循环。返回遍历的结果。如果不一样,则重新遍历,并递增retries。

当重试次数达到最大值时,就转为加锁计算size的方法。因此,size方法在高并发竞争激烈的环境下,非常影响性能。

    public int size() {
        // Try a few times to get accurate count. On failure due to
        // continuous async changes in table, resort to locking.
        final Segment<K,V>[] segments = this.segments;
        int size;
        boolean overflow; // true if size overflows 32 bits
        long sum;         // sum of modCounts
        long last = 0L;   // previous sum
        int retries = -1; // first iteration isn't retry
        try {
            for (;;) {
                if (retries++ == RETRIES_BEFORE_LOCK) {   //1
                    for (int j = 0; j < segments.length; ++j)
                        ensureSegment(j).lock(); // force creation
                }
                sum = 0L;
                size = 0;
                overflow = false;
                for (int j = 0; j < segments.length; ++j) {  //2
                    Segment<K,V> seg = segmentAt(segments, j);
                    if (seg != null) {
                        sum += seg.modCount;
                        int c = seg.count;
                        if (c < 0 || (size += c) < 0) 
                            overflow = true;
                    }
                }
                if (sum == last)  // 3
                    break;
                last = sum;
            }
        } finally {
            if (retries > RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    segmentAt(segments, j).unlock();
            }
        }
        return overflow ? Integer.MAX_VALUE : size;
    }

Ref

  1. developpaper.com/learn-more-…
  2. altair.cs.oswego.edu/pipermail/c…
  3. stackoverflow.com/questions/2…