跳跃表 Skip List

235 阅读3分钟

  链表由于其简单以及易于实现而为人所熟悉,也常常被应用到实际开发中。其各种操作的时间复杂度分别为

  • 查询:O(n)O(n)
  • 插入:O(1)O(1)
  • 删除:O(1)O(1)

链表

   可以看出,链表的搜索/查询速度相对比较慢。因为链表中元素的查询每次必须从第一个节点开始,每次只能遍历一个节点。但 SkipList 却可以实现节点的跳跃,可以降低查询操作的时间复杂度。

跳跃表(Skip List)

   跳跃表是一种概率数据结构,是对链表数据结构的扩展。在链表数据结构中,每一个节点指向其右侧相邻的节点。跳跃表在此基础上进行了扩展,除了指向右侧相邻节点的指针外(基础层级),部分节点还有其他指针指向右侧与其间隔 kkk0k ≥ 0) 个节点的节点,指针所在的层级越高,层级上的节点数量越少。

⓵ 完美跳跃表(Perfect Skip List)

   所谓完美跳跃表,是指随着指针层级(l(l0)l(l ≥ 0))的增加,每两个相邻节点之间的间隔 2l2^l 个节点。

完美跳跃表

   完美跳跃表具有以下特性:

  • 如果节点的个数为 nn ,则层数为 ceil(log2n)ceil(\log_2 n)
  • ll 层的节点个数只有第 l1l - 1 层节点个数的一半

  这样,查询的时间复杂度降低为 O(logn)O(\log n) ,查询的性能得到了提升,但节点的插入和删除变得非常困难,因为每插入/删除一个节点,节点之间的指针都需要重新设置一遍。

⓶ 随机跳跃表

  完美跳跃表虽然提高了查询的性能,但节点的插入和删除却非常难于维护。所以,我们只能期望一个跳跃表满足完美跳跃表的要求,但实际每个节点包含的指针层级只能随机,这就是随机跳跃表。

   具体每个插入跳跃表的节点包含的指针层级应该如何确定,这里以抛硬币为例,在硬币落地为背面之前,连续得到正面的次数即为节点所包含的指针的层级数。同时,假设当前跳跃表中的节点个数为 nn ,那么理论上,在插入一个新的节点之后,跳跃表的层级不应该超过 ceil(log2(n+1))ceil(\log_2 (n + 1)) 。所以,新插入的节点所包含的指针的层级数应该取二者中较小的一个。

随机跳跃表

⓷ 代码实现

import random
import math


class Node:
    def __init__(self, key=-1, level=0):
        self.key = key
        self.next = [None] * (level + 1)

    def __repr__(self):
        return str(self.__dict__)

    def __str__(self):
        return str(self.__dict__)


class SkipList:
    def __init__(self):
        self.header = Node()
        self.length = 0
        self.maxLevel = 0

    def get_level(self):
        """
        随机获取新插入的节点应该包含的指针的层级
        但理论上不应该超过跳跃表所包含的节点个数的对数
        """
        level = 0
        while random.randint(1, 2) != 2 and level < math.log2(self.length + 1):
            level += 1

        return level

    def get_target(self, key):
        """
        取得执行插入/删除/查询操作的位置
        """
        target = [None] * (self.maxLevel + 1)
        node = self.header

        for i in reversed(range(self.maxLevel + 1)):
            while node.next[i] is not None and node.next[i].key < key:
                node = node.next[i]
            target[i] = node

        return target

    def search_node(self, key):
        target = self.get_target(key)

        if len(target) > 0:
            if target[0].next[0].key == key:
                return target[0].next[0]

        return None

    def insert_node(self, key):
        """
        节点插入
        """
        target = self.get_target(key)

        # 只有要插入的元素在跳跃表中不存在或跳跃表为空时才执行插入
        if target[0].next[0] is None or key != target[0].next[0].key:
            level = self.get_level()
            node = Node(key, level)
            min_level = min(level, self.maxLevel)

            for i in range(min_level + 1):
                node.next[i] = target[i].next[i]
                target[i].next[i] = node

            # 如果新插入的节点包含的指针层级超过了当前跳跃表的最高层级,那么应该相应的增加 header 节点的层级
            if self.maxLevel < level:
                self.header.next.extend([None] * (level - self.maxLevel))
                for i in range(self.maxLevel + 1, level + 1):
                    self.header.next[i] = node
                self.maxLevel = level
            self.length += 1

    def delete_node(self, key):
        """
        节点删除
        """
        target = self.get_target(key)

        if target[0].next[0].key == key:
            node = target[0].next[0]
            for i in range(len(node.next)):
                target[i].next[i] = node.next[i]
                # 删除节点可能伴随着跳跃表最高层级降低
                if self.header.next[i] is None:
                    self.maxLevel -= 1
            self.length -= 1

        return

    def print_node(self):
        print("print node start")
        for i in reversed(range(self.maxLevel + 1)):
            print("level = {}".format(i))
            node = self.header
            while node.next[i] is not None:
                node = node.next[i]
                print(node.key, end=' ')
            print()


lst = SkipList()
lst.insert_node(3)
lst.insert_node(6)
lst.insert_node(7)
lst.insert_node(9)
lst.insert_node(12)
lst.insert_node(19)
lst.insert_node(17)
lst.insert_node(26)
lst.insert_node(21)
lst.insert_node(25)

node = lst.search_node(12)
print(node)

lst.print_node()
lst.delete_node(12)
lst.print_node()