跳跃表ConcurrentSkipListMap源码解析

676 阅读14分钟

最近在看 Redis 设计与实现,在 Redis 底层数据结构用到了跳跃表,趁着这次需求,看了一下 Java 基于跳跃表实现的集合。

跳跃表

跳跃表(SkipList)是一种有序的数据结构,每个节点维持着多个指向其他节点的指针,从而达到快速访问的目的。大部分情况下,跳跃表的查询效率可以和平衡树媲美,并且实现比平衡树简单,因为得到了广泛的应用,这 里只将ConcurrentSkipListMap的实现。 在这里插入图片描述 如上图是 ConcurrentSkipListMap 可能出现的结构图,接下来我们看 new ConcurrentSkipListMap() 会做什么操作,首先看一下主要代码。

	public ConcurrentSkipListMap() {
		// comparator 是统一比较器
    	this.comparator = null;
    	initialize();
	}
	private void initialize() {
        keySet = null;
        entrySet = null;
        values = null;
        descendingMap = null;
        head = new HeadIndex<K,V>(new Node<K,V>(null, BASE_HEADER, null),
                                  null, null, 1);
    }
    /**头部索引*/
    static final class HeadIndex<K,V> extends Index<K,V> {
        final int level;
        HeadIndex(Node<K,V> node, Index<K,V> down, Index<K,V> right, int level) {
            super(node, down, right);
            this.level = level;
        }
    }
    /**索引*/
    static class Index<K,V> {
        final Node<K,V> node;
        final Index<K,V> down;
        volatile Index<K,V> right;
        
 		Index(Node<K,V> node, Index<K,V> down, Index<K,V> right) {
            this.node = node;
            this.down = down;
            this.right = right;
        }

这里可以看到 new ConcurrentSkipListMap() ,只是创建一个原始头索引,没有分层,没有链表。接下来看他的 put 方法,假如我们要 put key 为 5 的数据。

    public V put(K key, V value) {
        if (value == null)
            throw new NullPointerException();
        return doPut(key, value, false);
    }

只是进行了一下 value 不能为 null 的判断,主要执行逻辑在 doPut() 方法,由于 doPut() 方法太长,这里把该方法拆成三部分

1、创建需要插入的节点

首先找到合适的 Node 位置,注意:这里获得的 Node 并不一定是要插入节点的前置节点,获取到的 Node 满足两个条件:

	1. 位于最底层; 	
	2. 节点的key小于新插入节点的key

然后我们设当前遍历的节点为 b,目前遍历的节点的下个节点为 n, 要新插入的节点为 z ,递归找到要插入点,插入点需要满足两个条件:

	 1. b < z <= n 
 	 2. b 和 n 最接近
private V doPut(K key, V value, boolean onlyIfAbsent) {
        // z 为要新创建的 Node
        Node<K,V> z;             // added node
        if (key == null)
            throw new NullPointerException();
        Comparator<? super K> cmp = comparator;
        outer: for (;;) {
            /*
            * b:目前遍历到的节点,初始为 新插入节点的前置节点
            * n:就是目前遍历到的节点的下一个节点,初始为 新节点没插入之前的 b 的后置节点
            * */
            for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {
                if (n != null) {
                    Object v; int c;
                    // f:n 的后续节点
                    Node<K,V> f = n.next;
                    // 如果 n 不为 b 的后续节点,进行下一次 for (;;) ,可能因为其他线程已经插入了其他节点到b的后续节点
                    if (n != b.next)               // inconsistent read
                        break;
                    // 如果后续节点 n 的值为 null ,表明已经被删除,则删除该节点,用 n 的后续节点 f代替 n ,进行下一次 for (;;)
                    if ((v = n.value) == null) {   // n is deleted
                        n.helpDelete(b, f);
                        break;
                    }
                    // 如果b的 value 为 null,说明被其他线程删除,进行下一次 for (;;)
                    if (b.value == null || v == n) // b is deleted
                        break;
                    // 如果需要插入的 key 大于后续节点的 key,向后续节点推进
                    if ((c = cpr(cmp, key, n.key)) > 0) {
                        b = n;
                        n = f;
                        continue;
                    }
                    // 如果要插入的 key 与后续节点的 key 相同
                    if (c == 0) {
                        // 如果 onlyIfAbsent 为 true,替换 n 的 value 值,结束循环
                        // 如果 cas 替换值失败,进行下一次 for (;;)
                        if (onlyIfAbsent || n.casValue(v, value)) {
                            @SuppressWarnings("unchecked") V vv = (V)v;
                            return vv;
                        }
                        break; // restart if lost race to replace value
                    }
                    // else c < 0; fall through
                }

                // 创建要插入的 node 节点 z , next 指向为 n
                z = new Node<K,V>(key, value, n);
                // 把 b 的 next 指向改为新节点 z , cas 替换失败,进行下一次 for (;;)
                if (!b.casNext(n, z))
                    break;         // restart if lost race to append to b
                // 至此,最底层 Node 链表构成,但对于跳跃表来说,还需要构造上层索引以及其连接关系
                break outer;
            }
        }

其中有几个主要方法,首先看 findPredecessor(key, cmp) 方法,作用是找到合适的 Node 位置

    /**
     * 方法说明:
     *      返回跳跃表中置于底层比较合适的前置节点,仅是一个满足以下两个条件的节点:
     *          1、位于最底层索引;2、节点的key小于新插入节点的key
     *      注意:并不是说返回的节点后面就是新插入的节点,也不是说新插入的节点的后置节点就会是原本q.next
     *      最后几个节点的关系如下 q ---> XXX ---> 新插入的节点 ---> XXX ---> 原本q.next
     *      XXX 代表着可能会有多个间隔
     *  变量说明:
     *      q: 当前索引,并且是需要包含最终结果node的索引
     *      r: 当前索引的后续索引,还是用来判断后续索引的key是否大于需要插入的key
     */
    private Node<K,V> findPredecessor(Object key, Comparator<? super K> cmp) {
        if (key == null)
            throw new NullPointerException(); // don't postpone errors
        for (;;) {
            // 首先获取原始头索引 q,以及当前索引的后续索引 r
            for (Index<K,V> q = head, r = q.right, d;;) {
                // r 为 null ,只能说明,这一层索引到最后也没有比需要插入的 key 大的, 需要从下层继续找
                // 如果不属于最底层,把 q 更新为当前节点的下级节点, r 还是最新节点的下级节点,如果是最底层,就返回当前变量 q 的 node
                if (r != null) {
                    // 获取链表后续索引的节点:n
                    Node<K,V> n = r.node;
                    K k = n.key;
                    // 如果该节点的 value 为 null ,表示该节点已经被删除了,put 时也是不允许 value 为 null 的
                    // 这个判断的作用是把某个 value 为 null 的所有上层索引都 unlink 掉
                    if (n.value == null) {
                        // 删除空值索引,即把r的后续索引顶替掉r,删除失败重新进入内层 for 循环
                        if (!q.unlink(r))
                            break;           // restart
                        // r 再次设置为当前索引的右索引,进行下次内循环
                        r = q.right;         // reread r
                        continue;
                    }
                    // 如果需要插入的key比后续节点n的key大,跳跃到下个节点
                    // q --> r     r ---> r.right
                    // 直到n的key比需要插入的key大,跳出循环
                    if (cpr(cmp, key, k) > 0) {
                        q = r;
                        r = r.right;
                        continue;
                    }
                }
                // 比如此时的索引 q 是最低层索引,返回q的节点
                if ((d = q.down) == null)
                    return q.node;
                // 如果不是最低一层索引,将 q 赋值为 q 的下层索引,
                q = d;
                // r 赋值为 q 的后续索引
                r = d.right;
            }
        }
    }

在这里插入图片描述 上图为寻找返回值的路线,其中蓝线为 q 的指向变化,红线为 r 的指向变化,到最后 key 为 2 上面的 level 1 处的索引满足返回条件,返回 key 为 2 的 Node。

然后是 helpDelete 方法,只有满足 (v = n.value) == null 才会执行该方法,该方法会执行两次,以达到删除该节点,用 n 的后续节点 f 代替 n 的目的。

1. 满足 f == null || f.value != f 执行 然后会返回,调用此方法 达到 
   b ---> n --- > newNode<K,V>(f) --- > f 的情况 		 	
2. 在下次 for 循环执行到这里,会满足 else 判断, 把 b 的 next
   指向 f,n 和 n.next 就不在链表里了
  void helpDelete(Node<K,V> b, Node<K,V> f) {
            /*
             * Rechecking links and then doing only one of the
             * help-out stages per call tends to minimize CAS
             * interference among helping threads.
             */
            // 如果 f 是该 Node 的后续几点,并且 该 Node 节点是 b 的后续节点,一般情况一定满足,除非中途被删除
            if (f == next && this == b.next) {
                // 如果 f 为 null ,或者 f 的 value 不为 f 本身,创建一个新的节点,进行附加标记操作
                // 执行 casNext(f, new Node<K,V>(f)); 之后的结构为
                /*
                * b ---> n --- > new Node<K,V>(f) --- > f
                * */
                if (f == null || f.value != f) // not already marked
                    casNext(f, new Node<K,V>(f));
                else
                // 已经被标记,执行下面语句之后,结构为
                /*
                * b ---> f   n 和 new Node<K,V>(f) 不在链表里了
                * */
                    b.casNext(this, f.next);
            }
        }

在这里插入图片描述 上图是代码执行完第一步之后的情况

2、创建需要插入的节点

首先获取随机层数,从最底层依次向上构建新的 Index,如果随机层数大于原最高 level,最后重新构建 headIndex。


        // 获取随机数
        int rnd = ThreadLocalRandom.nextSecondarySeed();
        // 该随机数的二进制与 0x80000001 的二进制:10000000000000000000000000000001 进行与运算
        // 即:随机数的二进制最高位与最低位都为 0 ,其他位无所谓,如果不满足,不增加节点的层数,直接结束,不再进行第三步
        if ((rnd & 0x80000001) == 0) { // test highest and lowest bits
            // 初始 level 为 1
            int level = 1, max;
            // 判断随机值的二进制从倒数第二位开始向左有多少个连续的 1 ,就 ++level 几次
            while (((rnd >>>= 1) & 1) != 0)
                ++level;
            Index<K,V> idx = null;
            // 头索引 h
            HeadIndex<K,V> h = head;
            // max 赋值为头索引的层数,即目前跳跃表最高的层数;
            // 如果随机出的 level 小等于 max
            if (level <= (max = h.level)) {
                // 循环创建 z 的上层索引,此时的索引只是内部有指向 新节点 z ,然后指向刚创建的下级索引
                // 并没有左右关联到跳跃表中
                for (int i = 1; i <= level; ++i)
                    idx = new Index<K,V>(z, idx, null);
            } else { // try to grow by one level
                // 如果随机出的 level 大于 max,只取 level = max + 1
                level = max + 1; // hold in array and later pick the one to use
                // 创建一个 长度为 level + 1 的数组,因为要让下标从 1 开始,所以 + 1
                @SuppressWarnings("unchecked")Index<K,V>[] idxs =
                    (Index<K,V>[])new Index<?,?>[level+1];

                // 循环创建 z 的上层索引,此时的索引只是内部有指向 新节点 z ,然后指向刚创建的下级索引
                // 并没有左右关联到跳跃表中
                for (int i = 1; i <= level; ++i)
                    idxs[i] = idx = new Index<K,V>(z, idx, null);

                // 重新构建头索引
                for (;;) {
                    // h 重新赋值为 头索引,此变量也是最终的头索引
                    h = head;
                    // oldLevel 赋值为原本的 level
                    int oldLevel = h.level;
                    // 如果 level 小等于 oldLevel 说明其他线程修改了头循环的层数,重for (;;)进行
                    if (level <= oldLevel) // lost race to add level
                        break;
                    // 重新设置一个头索引
                    HeadIndex<K,V> newh = h;
                    // 获取头索引的节点
                    Node<K,V> oldbase = h.node;
                    // 循环创建头索引,一般来说只会循环一次
                    for (int j = oldLevel+1; j <= level; ++j)
                        newh = new HeadIndex<K,V>(oldbase, newh, idxs[j], j);
                    // cas 方式 用 newh 替换 h 做 head 头部节点,并不是 h = newh,竞争失败重新进入for (;;)
                    // 因为新链表层,只有头索引和新节点索引,不需要再次构建指向关系,但下层索引都需要插入新索引,所以做以下操作
                    /* 成功后 head 头部节点的状态:
                        head
                         ↓
                      原本的 head
                     */
                    if (casHead(h, newh)) {
                        // 把 h 赋值为 newh
                        h = newh;
                        // level 赋值为 原本头索引的 level,然后把 idx 赋值为 idxs 的 level 处的索引
                        idx = idxs[level = oldLevel];
                        break;
                    }
                }
            }

在这里插入图片描述

执行完第二步后,,我们假如,level 为 4 ,跳跃表的结构如上图所示。 变化如下:

 	1. 最高 level 变成了 4 
	2. 新构建了一个 headIndex 
	3. 新创建的 headIndex 的 right 指向新建节点的最高层索引,最高层索引的 right 指向 null
	4. 新创建的 Index 都依次往下指向
3、构建每层新增 Index 的指向关系

主要逻辑是从最高是原 level 那一层,依次比较 right,找到合适的位置,构建正确的指向关系,然后开始构建下一层。

   			// insertionLevel:目前需要构建指向关系的层数,初始为原头结点的层数
            splice: for (int insertionLevel = level;;) {
                // j : 初始为目前头索引的层数
                int j = h.level;
                // q :需要构建指向关系的最高层头索引,有可能会变成 q -> q.right -> q.right.right...
                // r :初始为 q 的第一个右索引,有可能会变成 r -> r.right -> r.right.right...
                // t :q 所在那一层的新增的索引
                // 循环构建每一层的新增节点的指向关系,从头索引开始,直到 需要插入的key,小于某个索引的 key
                for (Index<K,V> q = h, r = q.right, t = idx;;) {
                    // 如果 q 或 t 为 null ,表示其它线程删除了 q 或 t,重新进入 for (;;)
                    if (q == null || t == null)
                        break splice;
                    // 如果 r 为空,说明是该层最大的值
                    if (r != null) {
                        // n 赋值为 r 的 node
                        Node<K,V> n = r.node;
                        // compare before deletion check avoids needing recheck
                        // 把要插入的 key 和 n 的 key 比较
                        int c = cpr(cmp, key, n.key);
                        // put 时也是不允许 value 为 null 的,value 为 null ,表示已经被删除了, 删除空值索引
                        if (n.value == null) {
                            if (!q.unlink(r))
                                break;
                            r = q.right;
                            continue;
                        }
                        // 如果我们插入的key大于n的key,继续向后续推进
                        if (c > 0) {
                            q = r;
                            r = r.right;
                            continue;
                        }
                    }

                    // 如果目前头索引的层数与原头索引的层数相等,也就是层数没有变化
                    if (j == insertionLevel) {
                        // 构建后的指向关系为 : q ---> t ---> r,如果失败执行下一次 for (int insertionLevel = level;;)
                        if (!q.link(r, t))
                            break; // restart
                        // 如果新增的节点值为 null 标识该节点已被其他线程删除,执行下一次 for (int insertionLevel = level;;)
                        if (t.node.value == null) {
                            findNode(key);
                            break splice;
                        }
                        // 逐层自减,到最底层退出循环,完成 put 操作
                        if (--insertionLevel == 0)
                            break splice;
                    }

                    // q、r、t 随着节点层数下移而下移,准备下层构建操作
                    if (--j >= insertionLevel && j < level)
                        t = t.down;
                    q = q.down;
                    r = q.right;
                }
            }
        }
        return null;
    }

其中有一个 findNode(key) 方法,这里加上注释

private Node<K,V> findNode(Object key) {
        if (key == null)
            throw new NullPointerException(); // don't postpone errors
        Comparator<? super K> cmp = comparator;
        outer: for (;;) {
            // 找到目标节点的前置节点 b,n 为 b 的后置节点
            for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {
                Object v; int c;
                // 如果后置节点为 null ,结束操作
                if (n == null)
                    break outer;
                // 获取后续节点的后续节点
                Node<K,V> f = n.next;
                // 如果 n 部位前置节点的后续节点,说明已经被删除,进入下次 for (;;)
                if (n != b.next)                // inconsistent read
                    break;
                // 后续节点 n 的 value 为 null ,说明已经被删除,将 b 的 next 设置为 f ,进入下次 for (;;)
                if ((v = n.value) == null) {    // n is deleted
                    n.helpDelete(b, f);
                    break;
                }
                // 前置节点 b 的 value 为 null ,说明已经被删除,进入下次 for (;;)
                if (b.value == null || v == n)  // b is deleted
                    break;
                // 如果目标 key 与 n.key 相等,返回后续节点 n
                if ((c = cpr(cmp, key, n.key)) == 0)
                    return n;
                // 如果大于后续节点 key ,向后推进
                if (c < 0)
                    break outer;
                b = n;
                n = f;
            }
        }
        return null;
    }

至此,跳跃表 put 完成。 在这里插入图片描述 接下来看 get 方法,整体流程为:

	1、调用  findPredecessor 方法,从头索引向右开始查找,如果后续索引的节点 key 大于我们要查找的 key,则头索引向下移,
	   在下层 Index 查询,一直找到没有下层索引位置,返回 Node
	2、到这里可能有 value 为 null 的空值索引,表明已经被删除,用 CAS 删除这些无用节点
	3、从  findPredecessor 方法找到的 Node 开始向右遍历,直到某个节点的 key 与目标的 key 相等,返回结果,如果 小于目
	   标 key ,直接返回 null
 public V get(Object key) {
        return doGet(key);
    }
   private V doGet(Object key) {
        if (key == null)
            throw new NullPointerException();
        Comparator<? super K> cmp = comparator;
        outer: for (;;) {
            /*
             * b:目前遍历到的节点,初始为 新插入节点的前置节点
             * n:就是目前遍历到的节点的下一个节点,初始为 新节点没插入之前的 b 的后置节点
             * */
            for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {
                Object v; int c;
                if (n == null)
                    break outer;
                // f : n 的后置节点
                Node<K,V> f = n.next;
                // 如果 n 不为 b 的后续节点,进行下一次 for (;;) ,可能因为其他线程已经插入了其他节点到b的后续节点
                if (n != b.next)                // inconsistent read
                    break;
                // 如果后续节点 n 的值为 null ,表明已经被删除,则删除该节点,用 n 的后续节点 f代替 n ,进行下一次 for (;;)
                if ((v = n.value) == null) {    // n is deleted
                    n.helpDelete(b, f);
                    break;
                }
                // 如果b的 value 为 null,说明被其他线程删除,进行下一次 for (;;)
                if (b.value == null || v == n)  // b is deleted
                    break;
                // 想要查找的 key 与 n 的 key 相等,返回结果
                if ((c = cpr(cmp, key, n.key)) == 0) {
                    @SuppressWarnings("unchecked") V vv = (V)v;
                    return vv;
                }
                // 如果小于,直接退出 for 循环,返回 null
                if (c < 0)
                    break outer;
                // 如果需要插入的 key 大于后续节点的 key,向后续节点推进
                b = n;
                n = f;
            }
        }
        return null;
    }

最后是 remove 方法,这里需要说明一下,因为 ConcurrentSkipListMap 是支持并发的,因此再删除节点的时候可能会有其他线程在该位置进行插入,所以会在要删除的节点后面增加一个特殊节点进行标记,然后才会进行删除,解决在正在删除的后续新增数据,然后这个数据被删除掉的问题。 remove 方法主要分为三步:

 1. findPredecessor 方法获取合适的前置节点 b
 2. 获取 b 的后置节点 n ,然后是一系列处理并发的 CAS 操作,接着比较 n 的 key 和要删除的 key,如果要删除的 key 大
    于 n 的key 继续往后遍历,小于的话表明没有对应的 key ,直接结束
 2. 如果 key 相等,先把 n 的 value 置 null ,然后把b 指向 n.next,表示 n 在节点的链表里已经被删除,接着调用 
    findPredecessor 把 删除的节点 n 的索引在每一层索引层删除。
   public V remove(Object key) {
        return doRemove(key, null);
    }
final V doRemove(Object key, Object value) {
        if (key == null)
            throw new NullPointerException();
        Comparator<? super K> cmp = comparator;
        outer: for (;;) {
            /*
             * 还是 findPredecessor 方法
             * b:目前遍历到的节点,初始为 新插入节点的前置节点
             * n:就是目前遍历到的节点的下一个节点,初始为 新节点没插入之前的 b 的后置节点
             * */
            for (Node<K,V> b = findPredecessor(key, cmp), n = b.next;;) {
                Object v; int c;
                // n 为 null ,表明已经被删除,退出循环
                if (n == null)
                    break outer;
                // f : n 的后续节点
                Node<K,V> f = n.next;
                // n 不为 b 的 后续节点了,读取不一致,进行下次 for (;;)
                if (n != b.next)                    // inconsistent read
                    break;
                // n 的 value 为 null,进行标记或删除操作,达到 删除 n 节点的操作
                if ((v = n.value) == null) {        // n is deleted
                    n.helpDelete(b, f);
                    break;
                }
                // 说明其他线程删除了 b,或者已经被标记为要删除,进入下次 for (;;)
                // helpDelete 方法里会先进行标记操作,来标记要被删除,相关指向关系为如下时,就是被标记为要被删除
                //                          ↓
                // 要被删除的 Node ---> new Node<K,V>(Node 的后续节点) ---> 要被删除的 Node 原本的后续节点
                if (b.value == null || v == n)      // b is deleted
                    break;
                // 比较目标 key  与 n 的 key
                // 小于 n 的 key 表明没有对应节点,结束删除操作
                if ((c = cpr(cmp, key, n.key)) < 0)
                    break outer;
                // 大于 0 往后续节点遍历
                if (c > 0) {
                    b = n;
                    n = f;
                    continue;
                }
                // 下面是相等的情况下的操作
                // value != null 需要判断 value 是否相同才会进行删除操作
                if (value != null && !value.equals(v))
                    break outer;
                // 首先把 n 的 value 设为 null,还没有从链表里把 n 删除,失败了继续下次 for (;;)
                if (!n.casValue(v, null))
                    break;
                // appendMarker 为附加标记,标记 n 节点要被删除,执行后相关指向为: b ---> n --- > new Node<K,V>(f) --- > f
                // 然后把 b 的后续节点指向 f,此时在跳跃表里已没有任何节点指向 n 节点,但 n 节点的索引 Index 还在跳跃表里
                if (!n.appendMarker(f) || !b.casNext(n, f))
                    findNode(key);                  // retry via findNode
                else {
                    // 此方法不仅是用来找合适的前置节点,其中有个 unlink 方法还会把空值索引给取消关联
                    // 这里执行完就会把 n 的索引在跳跃表里删除
                    findPredecessor(key, cmp);      // clean index
                    if (head.right == null)
                        tryReduceLevel();
                }
                @SuppressWarnings("unchecked") V vv = (V)v;
                return vv;
            }
        }
        return null;
    }

总结

由于 ConcurrentSkipListMap 保存的是键值对,所以使用 Node 来保存数据,并组成完整的数据链表,不参与构建跳跃表结构。主要是通过 Index 来实现跳跃表,每个 Index 都有一个 Node 的指向,就是说 Index 作为索引,是用来加快查询效率,Node 才是真正存储数据的。 并且我们可以发现,在 put 、remove 方法上都没有锁的参与,都是通过 CAS + for 循环完成的,所以 ConcurrentSkipListMap 的效率是很快的,如果应用需要有序性,那么跳表是一个很好的选择。