ConcurrentHashMap源码解析

489 阅读9分钟

前言

本文对于ConcurrentHashMap的源码解析是基于JDK1.8,JDK1.8的ConcurrentHashMap相较于JDK1.7差别还是蛮大的,在底层数据结构改变的基础上,锁机制也有了很强大的优化,由于红黑树的操作需要一些前置知识,后面会专门写一篇介绍红黑树的文章

ConcurrentHashMap源码解析

构造器

ConcurrentHashMap的数组是延迟加载的,所有构造器都是对参数做处理,没有初始化数组

// 空参构造器没有做任何处理
public ConcurrentHashMap() {
}

// 这个构造器比较重要,会在传入的初始容量的基础上找到比当前容量大的最接近2的幂次方,比如32,初始容量就是64
public ConcurrentHashMap(int initialCapacity) {
    if (initialCapacity < 0)
      throw new IllegalArgumentException();
    int cap = ((initialCapacity >= (MAXIMUM_CAPACITY >>> 1)) ?
               MAXIMUM_CAPACITY :
               tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1));
    this.sizeCtl = cap;
}
重要的成员变量和方法
// concurrentHashMap中存储元素的数组
transient volatile Node<K,V>[] table;

// 扩容时,会将原数组中的数据转移到nextTable
private transient volatile Node<K,V>[] nextTable;

// 单线程情况,维护数组元素的个数
private transient volatile long baseCount;

// 多线程情况下,这个数组加上baseCount共同维护map数组元素的个数
private transient volatile CounterCell[] counterCells;

// sizeCtl是一个非常重要的变量,具有非常丰富的含义
// 1.sizeCtl > 0,如果数组未初始化,那么记录的是数组的初始容量,如果数组已经初始化,那么记录的是数组的扩容阈值
// 2.sizeCtl = -1,表示数组正在进行初始化
// 3.sizeCtl < 0 and sizeCtl != -1,表示数组正在扩容,关于这个标志位的解释不要看代码注释,代码注释是错的
private transient volatile int sizeCtl;

put方法
static final int HASH_BITS = 0x7fffffff;

// 充分利用hash的高低位进行运算,table较小时,降低hash碰撞概率,与HASH_BITS进行&的作用是屏蔽hashcode的第一位(符号位)
// spread返回的值为非负数
static final int spread(int h) {
  	return (h ^ (h >>> 16)) & HASH_BITS;
}

添加元素时,只会对某个桶位进行加锁,不会影响其它桶位,提高了并发效率

final V putVal(K key, V value, boolean onlyIfAbsent) {
  	// concurrentHashMap key或者value都不能为null
        if (key == null || value == null) throw new NullPointerException();
  	// 对key的hashcode进行一定的扰动,并屏蔽符号位
        int hash = spread(key.hashCode());
  	// 记录数组中某个桶位元素的个数
        int binCount = 0;
  	// 注意:这里是一个for循环,for循环中代码较多,容易忘记
        for (Node<K,V>[] tab = table;;) {
            Node<K,V> f; int n, i, fh;
            if (tab == null || (n = tab.length) == 0)
              	// 当数组还未初始化时,先初始化数组,concurrentHashMap,数组是延迟初始化的
                tab = initTable();
            // 这个桶位没有元素
            else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
              	// CAS设置元素,设置成功退出for循环
                if (casTabAt(tab, i, null,
                             new Node<K,V>(hash, key, value, null)))
                    break;                  
            }
            // 如果桶位中该元素的hash值为-1,表示数组正在扩容中,数组扩容下文有详细的解释
            else if ((fh = f.hash) == MOVED)
              	// 进行协助扩容
                tab = helpTransfer(tab, f);
            else {
                V oldVal = null;
              	// 对该桶位进行加锁,保证线程安全,不会影响到其它桶位
                synchronized (f) {
                    // double check,防止该节点树化,这个节点如果还是原来的节点表示该节点没有更改过
                    if (tabAt(tab, i) == f) {
                      	// 普通链表节点
                        if (fh >= 0) {
                            // 注意:这里是从1开始
                            binCount = 1;
                            for (Node<K,V> e = f;; ++binCount) {
                                K ek;
                              	// 找到了一个与插入元素key完全一致的数据,进行value的更新
                                if (e.hash == hash &&
                                    ((ek = e.key) == key ||
                                     (ek != null && key.equals(ek)))) {
                                    oldVal = e.val;
                                    if (!onlyIfAbsent)
                                        e.val = value;
                                    break;
                                }
                                Node<K,V> pred = e;
                              	// 没有找到与插入元素key完全一致的数据,链表后新增节点
                                if ((e = e.next) == null) {
                                    pred.next = new Node<K,V>(hash, key,
                                                              value, null);
                                    break;
                                }
                            }
                        }
                      	// 树节点
                        else if (f instanceof TreeBin) {
                            Node<K,V> p;
                            binCount = 2;
                            // 新增树节点
                            if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                                           value)) != null) {
                                oldVal = p.val;
                                if (!onlyIfAbsent)
                                    p.val = value;
                            }
                        }
                    }
                }
                if (binCount != 0) {
                    // 如果当前数组的长度大于等于64且链表上的节点超过9个,进行树化(binCount从1开始)
                    if (binCount >= TREEIFY_THRESHOLD)
                        treeifyBin(tab, i);
                    // 重复键,将旧值返回
                    if (oldVal != null)
                        return oldVal;
                    break;
                }
            }
        }
  	// 添加的是新元素,维护数组元素个数
        addCount(1L, binCount);
        return null;
    }

put方法流程图

image.png

数组的初始化

private final Node<K,V>[] initTable() {
      Node<K,V>[] tab; int sc;
      // while会不断进行循环,相当于不断自旋,通过自旋+CAS,进行数组的初始化,保证线程安全
      while ((tab = table) == null || tab.length == 0) {
        // 当sizeCtl < 0 时,数组要么在初始化,要么在扩容
        if ((sc = sizeCtl) < 0)
          // 此时说明有其它线程正在进行初始化,该线程让出CPU执行权
          Thread.yield(); 
        // 将sizeCtl置为-1,表示数组正在初始化
        else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
          try {
            // 做一次double check,避免重复初始化
            if ((tab = table) == null || tab.length == 0) {
              // 如果sc > 0,取sc,否则取默认容量
              int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
              @SuppressWarnings("unchecked")
              // 数组初始化,容量为n
              Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
              table = tab = nt;
              // 计算扩容阈值,n >>> 2 表示将n/4,n-n/4=0.75n,这个设计非常巧妙,计算起来非常高效
              sc = n - (n >>> 2); 
            }
          } finally {
            // 扩容完成,修改sizeCtl,此时sizeCtl表示数组扩容阈值
            sizeCtl = sc;
          }
          break;
        }
      }
      return tab;
}
addCount方法:维护数组元素数量
private final void addCount(long x, int check) {
        CounterCell[] as; long b, s;
  	// 这里有两种情况
  	// 1.counterCells为null,说明此前不存在多线程竞争,对baseCount进行CAS累加即可,累加成功即可进行下面的操作
  	// 2.counterCells为null,但是对baseCount进行CAS累加失败了;或者counterCells不为null,都说明要么此前已经存在多线程
  	// 竞争,要么当前存在多线程竞争,就要使用counterCells来维护数组元素个数
        if ((as = counterCells) != null ||
            !U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
            CounterCell a; long v; int m;
            boolean uncontended = true;
            // 这里有几个判断
            // 1.counterCells未初始化
            // 2.counterCells数组中某个元素的值为空
            // 3.counterCells数组中某个元素的值不为空,但是对它进行CAS累加的失败失败了,说明出现了并发问题
            if (as == null || (m = as.length - 1) < 0 ||
                (a = as[ThreadLocalRandom.getProbe() & m]) == null ||
                !(uncontended =
                  U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
                fullAddCount(x, uncontended);
                return;
            }
            // 链表长度<=1,不需要进行扩容检查
            if (check <= 1)
                return;
            // 获取数组元素个数
            s = sumCount();
        }
        if (check >= 0) {
            Node<K,V>[] tab, nt; int n, sc;
            // 这里有几个判断
            // 1.数组元素个数达到扩容阈值
            // 2.数组不为空
            // 3.数组长度小于限定的最大值
            // 满足这3个条件,进行扩容
            while (s >= (long)(sc = sizeCtl) && (tab = table) != null &&
                   (n = tab.length) < MAXIMUM_CAPACITY) {
              	// 下文有关于这个函数的解释
                int rs = resizeStamp(n);
              	// 扩容时,第一次有线程走到这里sc是不可能小于0的,此时sc为扩容的阈值
              	// 当sc < 0,说明此时有线程正在进行扩容
                if (sc < 0) {
                    // 这里有几个判断
                    // 1.(nt = nextTable) == null表示扩容结束
                    // 2.transferIndex <= 0 表示
                    if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                        sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                        transferIndex <= 0)
                        break;
                    // 更新协助扩容的线程的数量
                    if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                        transfer(tab, nt);
                }
              	// rs 左移动16位会变成很大的负数,具体解释见下文
              	// 将sc变成一个小于0的值
                else if (U.compareAndSwapInt(this, SIZECTL, sc,
                                             (rs << RESIZE_STAMP_SHIFT) + 2))
                    // 进行扩容
                    transfer(tab, null);
                s = sumCount();
            }
        }
    }
static final int resizeStamp(int n) {
   // Integer.numberOfLeadingZeros(n):该方法的作用是返回无符号整数i的最高非0位前面的0的个数
   // 比如说,10的二进制表示为 0000 0000 0000 0000 0000 0000 0000 1010,java的整型长度为32位。那么这个方法返回的就是28
   return Integer.numberOfLeadingZeros(n) | (1 << (RESIZE_STAMP_BITS - 1));
}


00000000 00000000 00000000 00000001
00000000 00000000 10000000 000000001左移15位
--------------------------------------------
  
00000000 00000000 10000000 00000000
00000000 00000000 00000000 00010000  
00000000 00000000 10000000 00010000 | 完之后的值
  
扩容时,会将sizeCtl置为负数,具体操作是这样子
--------------------------------------------
00000000 00000000 10000000 00010000
10000000 00010000 00000000 00000000  rs << RESIZE_STAMP_SHIFT rs 左移动16位会变成很大的负数
  
--------------------------------------------
10000000 00010000 00000000 00000000  再对这个数 +2,
所以sizeCtl < 0 and sizeCtl != -1表示数组正在扩容,扩容的线程数为sizeCtl的低16位减1

fullAddCount:这个方法的主要作用是想要通过baseCount和CounterCell数组共同维护数组中元素的个数

这个方法非常复杂,我们先来了解一下基本思想:单线程情况下,我们只使用baseCount一个变量就足够维护数组元素的个数了,但是在并发量比较高的情况下,有多个线程共同更新baseCount,就会导致其中一个线程更新失败,所以大神就想出了一个办法,多创造几个值共同维护,线程更新baseCount失败了,就去更新数组中某个桶位的值就可以了

image.png

image.png

private final void fullAddCount(long x, boolean wasUncontended) {
        int h;
        if ((h = ThreadLocalRandom.getProbe()) == 0) {
            ThreadLocalRandom.localInit();      // force initialization
            h = ThreadLocalRandom.getProbe();
            // 重新生成了Probe值,wasUncontended冲突标志位设置为true
            wasUncontended = true;
        }
        boolean collide = false;                // True if last slot nonempty
        for (;;) {
            CounterCell[] as; CounterCell a; int n; long v;
            if ((as = counterCells) != null && (n = as.length) > 0) {
              	// counterCells数组中该位置的对象还未初始化
                if ((a = as[(n - 1) & h]) == null) {
                    // cellsBusy是一个标志位,表示CounterCell数组是否处于添加元素的状态
                    if (cellsBusy == 0) {            
                      	// 先创建CounterCell对象,并且把x赋值给了这个对象,x为1 
                        CounterCell r = new CounterCell(x); // Optimistic create
                      	// 判断cellsBusy是否为0,为0表示CounterCell数组为空闲状态
                      	// 将cellsBusy CAS为1,表示CounterCell目前处于添加元素的状态
                      	// 保证线程安全
                        if (cellsBusy == 0 &&
                            U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
                            boolean created = false;
                            try {               
                                CounterCell[] rs; int m, j;
                              	// 不断做double check,往数组中添加元素
                                if ((rs = counterCells) != null &&
                                    (m = rs.length) > 0 &&
                                    rs[j = (m - 1) & h] == null) {
                                    rs[j] = r;
                                    // created 元素是否添加成功的标志位
                                    created = true;
                                }
                            } finally {
                              	// 将cellsBusy置为0,表示CounterCell再次处于空闲状态
                                cellsBusy = 0;
                            }
                            // 数组元素添加成功,跳出循环
                            if (created)
                                break;
                            continue;           
                        }
                    }
                    collide = false;
                }
                // wasUncontended这个是一个传参,wasUncontended为false,说明之前一次进行CAS累加失败了
                else if (!wasUncontended) 
                    // 更新标识位,再次进行自旋累加
                    wasUncontended = true;     
              	// 如果CounterCell数组中该桶位已经创建了对象了,直接对该对象进行累加操作即可
                else if (U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))
                    break;
              	// 这里有两种情况
              	// 1.counterCells不等于as,说明有其它线程对counterCells进行了修改
              	// 2.n >= NCPU,说明并发的线程数已经超过CPU的数量,要进行限制
                else if (counterCells != as || n >= NCPU)
                    collide = false;            // At max size or stale
                else if (!collide)
                    collide = true;
              	// 对counterCells进行扩容
                else if (cellsBusy == 0 &&
                         U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
                    try {
                      	// double check
                        if (counterCells == as) {
                            // 扩容为当前的两倍
                            CounterCell[] rs = new CounterCell[n << 1];
                            for (int i = 0; i < n; ++i)
                                rs[i] = as[i];
                            counterCells = rs;
                        }
                    } finally {
                      	// 将cellsBusy置为0,表示CounterCell再次处于空闲状态
                        cellsBusy = 0;
                    }
                    collide = false;
                    // 再次进行自旋
                    continue;                   
                }
                h = ThreadLocalRandom.advanceProbe(h);
            }
            // 如果counterCells还未初始化,进行初始化操作
            else if (cellsBusy == 0 && counterCells == as &&
                     U.compareAndSwapInt(this, CELLSBUSY, 0, 1)) {
                boolean init = false;
                try {                          
                    if (counterCells == as) {
                      	// 初始化长度为2
                        CounterCell[] rs = new CounterCell[2];
                        rs[h & 1] = new CounterCell(x);
                        counterCells = rs;
                        init = true;
                    }
                } finally {
                    cellsBusy = 0;
                }
                if (init)
                    break;
            }
            // 如果counterCells靠不住,还是对baseCount进行CAS累加
            else if (U.compareAndSwapLong(this, BASECOUNT, v = baseCount, v + x))
                break;                          
        }
    }

sumCount:这个方法非常简单,就是将baseCount和CounterCell数组进行累加

final long sumCount() {
    CounterCell[] as = counterCells; CounterCell a;
    long sum = baseCount;
    if (as != null) {
      for (int i = 0; i < as.length; ++i) {
        if ((a = as[i]) != null)
          sum += a.value;
      }
    }
    return sum;
}
transfer:扩容

concurrentHashMap有一个协助扩容的概念,什么是协助扩容,就是当一个线程已经启动了扩容,这时候,有另外一个线程要对这个数组进行操作,并且这个操作的桶位已经进行了迁移或者正在迁移中,大神就不会让这个线程干等,而是会让这个线程协助扩容。

有两个地方会触发协助扩容:

  • 添加元素时,发现添加的元素对应的桶位是fwd节点
  • 添加完元素后,发现元素总数已经达到了阈值,并且sizeCtl的值小于0

协助扩容规则:首先原数组的迁移是从最后的一个索引开始往前迁移的, 往前迁移时,每次会给一个线程分配迁移的区域,默认是16个位置,直至迁移完成

image.png

private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    int n = tab.length, stride;
    // NCPU > 1:该服务器有多个CPU,可以进行多线程协助扩容,(n >>> 3) / NCPU这样计算出来的值如果小于16,则取16
    // NCPU <= 1:该服务器只有1个CPU,不会有多线程协助扩容,直接由一个线程做完所有扩容即可
    if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
      stride = MIN_TRANSFER_STRIDE; // subdivide range
    // 如果是扩容线程,此时新数组为null
    if (nextTab == null) {            // initiating
      try {
        @SuppressWarnings("unchecked")
        // 创建新数组,容量为之前的两倍
        Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
        nextTab = nt;
      } catch (Throwable ex) {      
        sizeCtl = Integer.MAX_VALUE;
        return;
      }
      // nextTable为全局变量,记录扩容后的数组
      nextTable = nextTab;
      // 记录线程开始迁移的桶位,从后往前迁移
      transferIndex = n;
    }
    // 记录新数组的末尾
    int nextn = nextTab.length;
    // 迁移过的节点会值会替换成fwd,表示该节点已经被迁移过
    ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
    boolean advance = true;
    boolean finishing = false; // to ensure sweep before committing nextTab
    // 注意,这里又是一个for死循环
    // i表示当前正在迁移桶位的索引
    // bound表示下一次任务迁移的开始桶位
    // 下面会进行赋值
    for (int i = 0, bound = 0;;) {
      Node<K,V> f; int fh;
      // 这个while是给线程指定分配区域的
      while (advance) {
        int nextIndex, nextBound;
        // 第一个线程走到这里的时候,--i不可能大于bound,finishing为false,所以判断不成立,不会走这里
        // finishing为true时表示迁移结束
        // --i >= bound表示没有需要继续迁移的桶位了
        if (--i >= bound || finishing)
          advance = false;
        // 第一个线程走到这里的时候,transferIndex记录开始迁移的桶位,不会<0,所以不会走这里
        // (nextIndex = transferIndex) <= 0 表示没有需要迁移的桶位了,就不用继续分配任务了
        else if ((nextIndex = transferIndex) <= 0) {
          i = -1;
          advance = false;
        }
        // 第一个线程进来,会走这个判断
        // 分配迁移的任务,每一次分配16个数组的长度
        else if (U.compareAndSwapInt
                 (this, TRANSFERINDEX, nextIndex,
                  nextBound = (nextIndex > stride ?
                               nextIndex - stride : 0))) {
          bound = nextBound;
          i = nextIndex - 1;
          advance = false;
        }
      }
      // 当前线程的扩容任务完成
      if (i < 0 || i >= n || i + n >= nextn) {
        int sc;
        // finishing为true表示所有线程的扩容任务都完成了
        if (finishing) {
          nextTable = null;
          table = nextTab;
          // 重新计算阈值,并且赋值给sizeCtl
          // n << 1 = 2n 
          // n >>> 1 = n/2
          // 2n - n/2 = 1.5n 就是0.75原来阈值的两倍
          sizeCtl = (n << 1) - (n >>> 1);
          return;
        }
        // 有线程完成了扩容任务,将扩容线程数-1
        if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
          // resizeStamp(n) << RESIZE_STAMP_SHIFT + 2 看回之前的代码,这个是扩容前sc的值
          // 如果sc=扩容前的值,表示所有扩容线程的任务都完成了
          // 如果sc!=扩容前的值,表示还有扩容线程的任务未完成
          // 需要进行自旋
          if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
            return;
          // 所有扩容线程的任务完成,finishing设置为true
          finishing = advance = true;
          i = n; // recheck before commit
        }
      }
      // 当前要迁移的桶位没有元素,直接将该桶位值替换为fwd
      else if ((f = tabAt(tab, i)) == null)
        advance = casTabAt(tab, i, null, fwd);
      // 当前要迁移的桶位已经迁移过了
      else if ((fh = f.hash) == MOVED)
        advance = true; // already processed
      else {
        // 开始迁移当前节点,加锁防止迁移过程中有元素添加进来
        synchronized (f) {
          // double check
          if (tabAt(tab, i) == f) {
            Node<K,V> ln, hn;
            if (fh >= 0) {
              // 注意,以下代码使用的n是原数组的数组长度
              int runBit = fh & n;
              Node<K,V> lastRun = f;
              // 从头节点开始循环,找到lastRun节点
              for (Node<K,V> p = f.next; p != null; p = p.next) {
                int b = p.hash & n;
                if (b != runBit) {
                  runBit = b;
                  lastRun = p;
                }
              }
              if (runBit == 0) {
                ln = lastRun;
                hn = null;
              }
              else {
                hn = lastRun;
                ln = null;
              }
              // 从头节点开始循环,使用头插法,将原链表分成两条链表
              for (Node<K,V> p = f; p != lastRun; p = p.next) {
                int ph = p.hash; K pk = p.key; V pv = p.val;
                if ((ph & n) == 0)
                  ln = new Node<K,V>(ph, pk, pv, ln);
                else
                  hn = new Node<K,V>(ph, pk, pv, hn);
              }
              // 低位链表放在扩容后数组的原位置
              setTabAt(nextTab, i, ln);
              // 高位链表放在扩容后数组的原位置+n
              setTabAt(nextTab, i + n, hn);
              // 对旧数组已经迁移过的桶位设置标识位
              setTabAt(tab, i, fwd);
              advance = true;
            }
            else if (f instanceof TreeBin) {
              TreeBin<K,V> t = (TreeBin<K,V>)f;
              TreeNode<K,V> lo = null, loTail = null;
              TreeNode<K,V> hi = null, hiTail = null;
              int lc = 0, hc = 0;
              for (Node<K,V> e = t.first; e != null; e = e.next) {
                int h = e.hash;
                TreeNode<K,V> p = new TreeNode<K,V>
                  (h, e.key, e.val, null, null);
                if ((h & n) == 0) {
                  if ((p.prev = loTail) == null)
                    lo = p;
                  else
                    loTail.next = p;
                  loTail = p;
                  ++lc;
                }
                else {
                  if ((p.prev = hiTail) == null)
                    hi = p;
                  else
                    hiTail.next = p;
                  hiTail = p;
                  ++hc;
                }
              }
              ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
              (hc != 0) ? new TreeBin<K,V>(lo) : t;
              hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
              (lc != 0) ? new TreeBin<K,V>(hi) : t;
              setTabAt(nextTab, i, ln);
              setTabAt(nextTab, i + n, hn);
              setTabAt(tab, i, fwd);
              advance = true;
            }
          }
        }
      }
    }
}

链表的扩容迁移思想其实跟HashMap是一模一样的:

  • 将原链表分为两条链表,低位链表放在新数组的原索引位处,高位链表放在新数组的原索引位+n的位置
  • 将链表中的节点分为两组的依据是将节点的hash值与原数组的长度n进行&操作,结果0放在低位链表中,结果1放在高位链表中
  • 迁移的过程中是使用头插法

这里比较不好理解的是lastRun节点,第一次for循环就是为了找到lastRun节点,lastRun结点实际上是最后几个具备相同 p.hash & n 值的连续结点的最上边结点,因为这样可以减少该结点下边几个结点的迁移工作

image.png