jdk7—ConcurrenHashMap源码解读

337 阅读5分钟

并发安全的选择

本文基于在了解HashMap的前提下进行分析,故建议先看HashMap源码:juejin.cn/post/685457…

已经知道HashMap是并发不安全的,那么该如何实现HashMap的并发安全?

这里有两种实现方式:1、采用HashTable的方式 2、采用ConcurrentHashMap

然而,尽管HashTable是并发安全的,但在实际应用中却不会使用它。

这里管中窥豹,从最常用的put方法来分析为什么不使用HashTable。

public synchronized V put(K key, V value) {
    // Make sure the value is not null
    if (value == null) {
        throw new NullPointerException();
    }

    // Makes sure the key is not already in the hashtable.
    Entry tab[] = table;
    int hash = hash(key);
    int index = (hash & 0x7FFFFFFF) % tab.length;
    for (Entry<K,V> e = tab[index] ; e != null ; e = e.next) {
        if ((e.hash == hash) && e.key.equals(key)) {
            V old = e.value;
            e.value = value;
            return old;
        }
    }

    modCount++;
    if (count >= threshold) {
        // Rehash the table if the threshold is exceeded
        rehash();

        tab = table;
        hash = hash(key);
        index = (hash & 0x7FFFFFFF) % tab.length;
    }

    // Creates the new entry.
    Entry<K,V> e = tab[index];
    tab[index] = new Entry<>(hash, key, value, e);
    count++;
    return null;
}

由此可见,HashTable在使用put方法时,相当于对这个对象加了锁,这种方法尽管能避免同时将元素加在同一位置造成的并发不安全问题,然而效率却十分低下。试想有两个元素同时put进HashTable中,尽管两元素并没有冲突,却仍然要等待锁的获得。

因此在面临并发安全问题时,我们通常选择使用ConcurrentHashMap。

ConcurrentHashMap

基本结构

ConcurrentHashMap采用的是分段锁,其基本结构不同于HashMap。Segment继承自ReentrantLock。

public ConcurrentHashMap() {
    this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}

默认初始容量:16。不同于HashMap,这个值指的是所有HashEntry的长度和。即共有16个Entry对象。

默认加载因子:0.75。

并发级别:16。指的是Segment的个数。

构造方法

public ConcurrentHashMap(int initialCapacity,
                         float loadFactor, int concurrencyLevel) {
    if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
        throw new IllegalArgumentException();
    if (concurrencyLevel > MAX_SEGMENTS)
        concurrencyLevel = MAX_SEGMENTS;
    int sshift = 0;
    int ssize = 1;
    //类似于HashMap中的初始化操作,size为大于当前传入并发水平的最小的2的幂次方数,并以此创建Segment数组对象。
    while (ssize < concurrencyLevel) {
        //size左移的次数。2^sshift=ssize
        ++sshift;
        ssize <<= 1;
    }
    //用于put方法,获取高sshift的二进制数
    this.segmentShift = 32 - sshift;
    //类似于HashMap,用于取下标
    this.segmentMask = ssize - 1;
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    //用于计算Segment下Entry数组的长度
    int c = initialCapacity / ssize;
    //相当于向上取整
    if (c * ssize < initialCapacity)
        ++c;
   	//MIN_SEGMENT_TABLE_CAPACITY=2,可见,即使默认initialCapacity和concurrencyLevel均为16,但要求了Segment的最小长度为2
    int cap = MIN_SEGMENT_TABLE_CAPACITY;
    while (cap < c)
        cap <<= 1;
    //对一个Segment对象的初始化,当需要用到时,直接将其置于Segment[]数组中。相当于一个原型。
    Segment<K,V> s0 =
        new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                         (HashEntry<K,V>[])new HashEntry[cap]);
    //根据sszie创建Segment数组对象
    Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
    //将s0放在ss的第0个位置
    UNSAFE.putOrderedObject(ss, SBASE, s0); 
    this.segments = ss;
}

put方法

大致流程:

(key,value)---->key.hashcode--->hashcode
    //Segment[] 数组的下标地址
int index=hashcode&segment[].length-1;
//Entry[]数组的下标地址
int index1=hashcode&Entry[].length-1
//放入Segment数组中
public V put(K key, V value) {
    Segment<K,V> s;
    if (value == null)
        throw new NullPointerException();
    //获取hash值
    int hash = hash(key);
    //Segment[]数组下标的位置
    int j = (hash >>> segmentShift) & segmentMask;
    //取Segment[]数组第j个位置的值
    if ((s = (Segment<K,V>)UNSAFE.getObject         
         (segments, (j << SSHIFT) + SBASE)) == null) 
        //如果第j个位置为空,生成一个Segment对象
        s = ensureSegment(j);
    return s.put(key, hash, value, false);
}
//31-为1的最高位前0的个数
SSHIFT = 31 - Integer.numberOfLeadingZeros(ss);

为什么Segment[]下的下标位置采用: (hash >>> segmentShift) & segmentMask?

个人认为还是为了增加散列性,但可能效果没那么明显,对于一个key其hashcode的前几位用于确定这个key在Segment数组的位置,后几位用于确定这个key在Entry数组的位置,增强hashcode的利用率。

//多次判断,保证赋值时并发安全,提高效率
private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    long u = (k << SSHIFT) + SBASE;
    Segment<K,V> seg;
    //防止并发安全问题,再次判断是否为null
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
        //获取原型对象及其属性
        Segment<K,V> proto = ss[0]; 
        int cap = proto.table.length;
        float lf = proto.loadFactor;
        int threshold = (int)(cap * lf);
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
            == null) {
            //真正创建一个Segment对象
            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                   == null) {
                //判断ss数组的第u个位置是否为null,如果为null则将s赋给seg,并返回
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                    break;
            }
        }
    }
    return seg;
}
//放入Entry数组中
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
    //trylock()尝试加锁
    HashEntry<K,V> node = tryLock() ? null :
        scanAndLockForPut(key, hash, value);
    V oldValue;
    try {
        HashEntry<K,V>[] tab = table;
        int index = (tab.length - 1) & hash;
        //下标index处的头节点
        HashEntry<K,V> first = entryAt(tab, index);
        //该节点遍历链表
        for (HashEntry<K,V> e = first;;) {
            if (e != null) {
                K k;
                if ((k = e.key) == key ||
                    (e.hash == hash && key.equals(k))) {
                    oldValue = e.value;
                    //如果查到相同的key,是否更新
                    if (!onlyIfAbsent) {
                        e.value = value;
                        ++modCount;
                    }
                    break;
                }
                e = e.next;
            }
            //执行条件:1、first为空 2、遍历到链尾尾
            else {
                if (node != null)
                    //头插法
                    node.setNext(first);
                else
                    //生成一个对象,头插法
                    node = new HashEntry<K,V>(hash, key, value, first);
                int c = count + 1;
                if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                    //扩容
                    rehash(node);
                else
                    //将node放在tab数组的index下标处
                    setEntryAt(tab, index, node);
                //操作次数
                ++modCount;
                //链表长度
                count = c;
                oldValue = null;
                break;
            }
        }
    } finally {
        unlock();
    }
    return oldValue;
}
final void setNext(HashEntry<K,V> n) {
    UNSAFE.putOrderedObject(this, nextOffset, n);
}

trylock和lock的区别

相同处:都是为了获取锁

trylock不会阻塞,根据是否获得锁,返回对应的布尔值。

lock会阻塞。

private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
    HashEntry<K,V> first = entryForHash(this, hash);
    HashEntry<K,V> e = first;
    HashEntry<K,V> node = null;
    int retries = -1; 
    //尝试请求锁失败
    while (!tryLock()) {
        HashEntry<K,V> f;
        if (retries < 0) {
            if (e == null) {
                if (node == null) 
                    //初始化节点,在此处只是预热相关代码,无太大意义
                    node = new HashEntry<K,V>(hash, key, value, null);
                retries = 0;
            }
            else if (key.equals(e.key))
                retries = 0;
            else
                e = e.next;
        }
        //当重试次数超过最大值时,直接请求锁,根据情况进入阻塞或获得该锁
        else if (++retries > MAX_SCAN_RETRIES) {
            lock();
            break;
        }
        //retries为偶时,判断此时头节点是否与先前进入的头节点一致,由此判断是否有并发问题,如果不一致,则更新first的值,并且将retries置-1,该线程重新判断
        else if ((retries & 1) == 0 &&
                 (f = entryForHash(this, hash)) != first) {
            e = first = f;
            retries = -1;
        }
    }
    return node;
}

rehash方法

扩容是单个Segment对象下的Entry[]进行扩容,而非整个Segment数组下的所有Entry[]进行扩容。

private void rehash(HashEntry<K,V> node) {
    //旧数组
    HashEntry<K,V>[] oldTable = table;
    //旧数组长度
    int oldCapacity = oldTable.length;
    //新数组长度
    int newCapacity = oldCapacity << 1;
    //获取新的阈值
    threshold = (int)(newCapacity * loadFactor);
    //生成新数组
    HashEntry<K,V>[] newTable =
        (HashEntry<K,V>[]) new HashEntry[newCapacity];
    //用于求下标
    int sizeMask = newCapacity - 1;
    //遍历旧数组,对旧数组元素的转移
    for (int i = 0; i < oldCapacity ; i++) {
        HashEntry<K,V> e = oldTable[i];
        //该数组下标不为null
        if (e != null) {
            //下一个节点
            HashEntry<K,V> next = e.next;
            //扩容后的下标
            int idx = e.hash & sizeMask;
            //next为null时,直接将e存储在新数组的指定位置处
            if (next == null)  
                newTable[idx] = e;
            //当不为null时,用lastRun指定最后一个链表的头节点,该节点在新数组的下标不同于旧数组,同时其每一个next节点所在数组下标都与该头节点相同;再遍历该头节点前的节点在新的数组中的下标
            else { 
                HashEntry<K,V> lastRun = e;
                int lastIdx = idx;
                for (HashEntry<K,V> last = next;
                     last != null;
                     last = last.next) {
                    int k = last.hash & sizeMask;
                    if (k != lastIdx) {
                        lastIdx = k;
                        lastRun = last;
                    }
                }
                newTable[lastIdx] = lastRun;
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }
    //对传入node节点位置的插入
    int nodeIndex = node.hash & sizeMask; 
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    table = newTable;
}

get方法

该方法比较简单,故提供源码,不再赘述

public V get(Object key) {
    Segment<K,V> s; // manually integrate access methods to reduce overhead
    HashEntry<K,V>[] tab;
    int h = hash(key);
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
    if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
        for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                 (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
             e != null; e = e.next) {
            K k;
            if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}

size()方法

该方法首先会在不锁Segment对象的情况下,获取该对象下的元素的个数,同时判断在获取个数的同时,对该Segment有无其他修改操作,如果有则重新计算。如果多次出现在求size的同时有其他线程进行修改操作的话,就改为先获取锁,然后统计。

public int size() {
    final Segment<K,V>[] segments = this.segments;
    int size;
    boolean overflow; // true if size overflows 32 bits
    long sum;         // sum of modCounts
    long last = 0L;   // previous sum
    int retries = -1; // first iteration isn't retry
    try {
        for (;;) {
            //RETRIES_BEFORE_LOCK默认为2 ,当重试次数等于该值时,直接加锁,进行统计
            if (retries++ == RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    ensureSegment(j).lock(); 
            }
            sum = 0L;
            size = 0;
            overflow = false;
            for (int j = 0; j < segments.length; ++j) {
                Segment<K,V> seg = segmentAt(segments, j);
                if (seg != null) {
                    //获取segment下的修改次数
                    sum += seg.modCount;
                    int c = seg.count;
                    if (c < 0 || (size += c) < 0)
                        overflow = true;
                }
            }
            //即该Segment没有进行修改操作时。退出循环
            if (sum == last)
                break;
            //赋segment最新修改的次数
            last = sum;
        }
    } finally {
        if (retries > RETRIES_BEFORE_LOCK) {
            for (int j = 0; j < segments.length; ++j)
                segmentAt(segments, j).unlock();
        }
    }
    return overflow ? Integer.MAX_VALUE : size;
}