结构
图摘自网络
ConcurrentHashMap是一个Segment数组。而一个Segement是一个HashEntry数组。而一个HashEntry则存储具体的key-value对。同时还扮演着链表节点的角色,用next执行链表中下一个节点。
ConcurrentHashMap借助segment来提高并发性能,发生在不同segment上的操作不存在线程安全性问题。一个segment可以看做一个小型的hashmap。ConcurentHashMap可以看作是由多个segment组成的一个联邦。 put,get,remove,定位到具体的segment然后交由segment去操作。
与扩容有关的capacity, threshold, loadfactor自然都是针对每一个segment而言的了。扩容同样也只发生在一个segmet中。
Segment
static class Segment<K,V> extends ReentrantLock implements Serializable {
transient volatile HashEntry<K,V>[] table;
transient int count;
transient int modCount;
}
Put
put操作定位到key对应的segemnt。(语句1),然后调用2返回key对应的segment(如果不存在该方法会初始化段) 最后将put委托给segment操作(语句3)。
public V put(K key, V value) {
Segment<K,V> s;
if (value == null)
throw new NullPointerException();
int hash = hash(key.hashCode());
int j = (hash >>> segmentShift) & segmentMask;
if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck
(segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment // 1
s = ensureSegment(j); //2
return s.put(key, hash, value, false); //3
}
ensureSegment
初始化段的时候没有加锁,所以这里存在竞争的情况。采用CAS来保证线程安全性。
@SuppressWarnings("unchecked")
private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset
Segment<K,V> seg;
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
Segment<K,V> proto = ss[0]; // use segment 0 as prototype
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // recheck
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) {
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}
Segment内部类中的Put方法
语句0首先调用scanAndForPut方法获取锁,这个方法会返回一个待插入的节点node。该方法的细节见后。
获取锁之后,就定位到table中的槽,然后遍历链表。
语句1处说明找到了key节点,则更新value。
语句2处如果node不为NULL,则直接利用该节点,设置该节点的next指针指向链表头节点(头插法插入新节点)
如果node为NULL,则此处新建一个节点。
语句3处segment中的节点数量加一,如果此时节点数目超过了阈值,则调用rehash方法扩容。
注意,到现在为止,我们仅仅完成了待插入节点的相关设置,我们还需要将table数组中对应的槽指向node(头插法插入新节点)。因此我们将node节点作为参数传到rehash方法交由他完成。
语句4处,如果不需要扩容,则调用setEntryAt(tab, index, node)来将node放到槽位中。
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
HashEntry<K,V> node = tryLock() ? null : //0
scanAndLockForPut(key, hash, value);
V oldValue;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
HashEntry<K,V> first = entryAt(tab, index);
for (HashEntry<K,V> e = first;;) {
if (e != null) {
K k;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
if (!onlyIfAbsent) {
e.value = value; // 1
++modCount;
}
break;
}
e = e.next;
}
else {
if (node != null) //2
node.setNext(first);
else
node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
rehash(node); // 3
else
setEntryAt(tab, index, node); //4
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
unlock();
}
return oldValue;
}
scanAndLockForPut
scanAndLockForPut这个方法用于插入时获取segment的锁。采用自旋的方式去获取锁,
当获取不到的时候,这个方法并没有闲着,而是沿着链表遍历找到put应该插入(或者更新)的位置。这个目的在于提前将链表节点加载进本地缓存中,这样等获取锁之后可以直接从缓存中读取,省去了cache miss 的开销。在这里,直接调用equals而不是先比较hash因为这里的遍历速度不重要。如果找到了插入的位置,用node提前分配好应该插入的节点。等到获取锁成功了返回node。如果找到了更新的位置(即链表中存在相同key的节点)那么返回的node 为NULL。(语句1)
用retries为负表示在寻找插入(更新)位置。当找位置的时候,循环会一直在语句1中执行。寻找完毕,retries为0。这时候会进入语句2,3。
语句2在每次自旋中增加retries次数判断是否达到上限,
语句3会在每次自旋次数为偶数时,检查hash对应的槽的链表的表头是否改变了,如果改变了,要重新遍历链表。 这是因为我们想尽可能的遍历到链表的新状态。否则如果其他持有锁的线程进行rehash等操作。导致链表状态大变。那么我们之前的遍历结果对于改善cache miss的效果便下降很多。
语句3 (retries & 1) == 0 还有一个特点是,对于单核处理器,retries的最大值为1。因此当第二次自旋时,不会去检查链表头是否被改变。因为我们这时候如果不去阻塞自己,那么便会占用持有锁的线程的CPU资源,降低了性能。
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
int retries = -1; // negative while locating node
while (!tryLock()) {
HashEntry<K,V> f; // to recheck first below
if (retries < 0) { // 1
if (e == null) {
if (node == null) // speculatively create node
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
}
else if (key.equals(e.key))
retries = 0;
else
e = e.next;
}
else if (++retries > MAX_SCAN_RETRIES) { //2
lock();
break;
}
else if ((retries & 1) == 0 && //3
(f = entryForHash(this, hash)) != first) {
e = first = f; // re-traverse if entry changed
retries = -1;
}
}
return node;
}
entryForHash方法用于返回指定segment中,hash对应的槽(bucket)。
@SuppressWarnings("unchecked")
static final <K,V> HashEntry<K,V> entryForHash(Segment<K,V> seg, int h) {
HashEntry<K,V>[] tab;
return (seg == null || (tab = seg.table) == null) ? null :
(HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
}
扩容rehash
扩容当且仅当put时发现节点数量超过阈值时,因此扩容时已经获取了segment对应的锁,无线程安全性问题。
扩容和hashmap扩容操作基本相同。扩容包含1. 重新生成table数组 2. 将旧table上节点迁移到新table上。
语句2寻找一个lastRun开头的链表,这个链表上的节点的新位置都相同,因此转移的时候只需要将lastRun开头的整个链表放到槽中。节约了复制的开销。
语句2.5处,将lastRun初始化为链表头的next节点,而不是链表头。(这里也可以初始化为链表头。)
语句3处从旧链表头部遍历到lastRun。复制其中的每一个节点到新的位置。
/**
* Doubles size of table and repacks entries, also adding the
* given node to new table
*/
@SuppressWarnings("unchecked")
private void rehash(HashEntry<K,V> node) {
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
int newCapacity = oldCapacity << 1; //1
threshold = (int)(newCapacity * loadFactor);
HashEntry<K,V>[] newTable =
(HashEntry<K,V>[]) new HashEntry[newCapacity];
int sizeMask = newCapacity - 1;
for (int i = 0; i < oldCapacity ; i++) {
HashEntry<K,V> e = oldTable[i];
if (e != null) {
HashEntry<K,V> next = e.next;
int idx = e.hash & sizeMask; //新的索引
if (next == null) // Single node on list
newTable[idx] = e;
else { // Reuse consecutive sequence at same slot //2
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
for (HashEntry<K,V> last = next; //2.5
last != null;
last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
newTable[lastIdx] = lastRun;
// Clone remaining nodes
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) { //3
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
int nodeIndex = node.hash & sizeMask; // add the new node
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
table = newTable;
}
get
get操作无需加锁,获取key对应的segment,获取segemnt对应的table,然后遍历table上槽对应的链表即可。
public V get(Object key) {
Segment<K,V> s; // manually integrate access methods to reduce overhead
HashEntry<K,V>[] tab;
int h = hash(key.hashCode());
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}
remove
public V remove(Object key) {
int hash = hash(key.hashCode());
Segment<K,V> s = segmentForHash(hash);
return s == null ? null : s.remove(key, hash, null);
}
Segemnt的remove
final V remove(Object key, int hash, Object value) {
if (!tryLock())
scanAndLock(key, hash);
V oldValue = null;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
HashEntry<K,V> e = entryAt(tab, index);
HashEntry<K,V> pred = null;
while (e != null) {
K k;
HashEntry<K,V> next = e.next;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
V v = e.value;
if (value == null || value == v || value.equals(v)) {
if (pred == null)
setEntryAt(tab, index, next);
else
pred.setNext(next);
++modCount;
--count;
oldValue = v;
}
break;
}
pred = e;
e = next;
}
} finally {
unlock();
}
return oldValue;
}
Size
可能返回一个过期值
遍历MAp中节点数量的方法有2种:一种是不加锁算2次,比较这2次的modCount是否变化来判断是否存在竞争。 另一种是当发现竞争时转而采用给每一个segment加锁来计算。size()方法,isEmpty()方法都采用了这种。
语句2 遍历segement数组,累加每一个segemnt的size与modCount。同时检查size是否溢出
语句3 检查modCount与之前的modCount是否一样。为了保证安全性,会至少遍历2轮,如果这一轮与上一轮modCount总和一样的话,就会退出循环。返回遍历的结果。如果不一样,则重新遍历,并递增retries。
当重试次数达到最大值时,就转为加锁计算size的方法。因此,size方法在高并发竞争激烈的环境下,非常影响性能。
public int size() {
// Try a few times to get accurate count. On failure due to
// continuous async changes in table, resort to locking.
final Segment<K,V>[] segments = this.segments;
int size;
boolean overflow; // true if size overflows 32 bits
long sum; // sum of modCounts
long last = 0L; // previous sum
int retries = -1; // first iteration isn't retry
try {
for (;;) {
if (retries++ == RETRIES_BEFORE_LOCK) { //1
for (int j = 0; j < segments.length; ++j)
ensureSegment(j).lock(); // force creation
}
sum = 0L;
size = 0;
overflow = false;
for (int j = 0; j < segments.length; ++j) { //2
Segment<K,V> seg = segmentAt(segments, j);
if (seg != null) {
sum += seg.modCount;
int c = seg.count;
if (c < 0 || (size += c) < 0)
overflow = true;
}
}
if (sum == last) // 3
break;
last = sum;
}
} finally {
if (retries > RETRIES_BEFORE_LOCK) {
for (int j = 0; j < segments.length; ++j)
segmentAt(segments, j).unlock();
}
}
return overflow ? Integer.MAX_VALUE : size;
}