// 跳表 (Skip List)
//
// Source blog metadata: 74 阅读 · 1 分钟 (74 reads, ~1 minute read)
/* A simple skip list implementation. */
#ifndef SKIPLIST_H_

#define SKIPLIST_H_

#include <bits/stdc++.h>

/**
 * One skip-list node: a stored value plus one "next" pointer for every
 * level the node participates in.
 */
struct Node {
    static const uint8_t MAX_LEVEL{32};  // default number of forward slots

    int val;                     // stored value
    std::vector<Node*> forward;  // forward[i]: successor on level i, or nullptr

    // Defaults: val = -1 and MAX_LEVEL null links.
    Node(int value = -1, uint8_t levelCount = MAX_LEVEL)
        : val{value}, forward(levelCount, nullptr) {}
};

class SkipList {
public:
    SkipList() : level(0), head(new Node()), dis(0, 1) {}

    // 查找
    bool search(int target);

    // 插入
    void add(int number);

    // 删除
    bool erase(int number);

    // 生成随机层数
    uint8_t randomLevel();

private:
    constexpr static uint8_t MAX_LEVEL{32}; // 最大跳表高度
    constexpr static float P{0.25f}; // 跳往下一层的概率
    uint8_t level;   
    Node* head; // 跳表头结点
    std::mt19937 gen{std::random_device{}()};
    std::uniform_real_distribution<double> dis;
};

bool SkipList::search(int target) {
    Node* curr = head;
    for (int i = level - 1; i >= 0; --i) {
        while (curr->forward[i] != nullptr && curr->forward[i]->val < target) {
            curr = curr->forward[i];
        }
    }
    curr = curr->forward[0];
    return curr && curr->val == target;
}

void SkipList::add(int target) {
    Node* curr = head;
    std::vector<Node*> tmp(level, nullptr);
    for (int i = level - 1; i >= 0; --i) {
        while (curr->forward[i] != nullptr && curr->forward[i]->val < target) {
            curr = curr->forward[i];
        }
        tmp[i] = curr;
    }
    
    uint8_t lv = randomLevel(); // 随机层数
    Node* now = new Node(target, lv);

    for (int i = lv - 1; i >= 0; --i) {
        if (i >= level) {
            head->forward[i] = now;
        } else {
            now->forward[i] = tmp[i]->forward[i];
            tmp[i]->forward[i] = now;
        }
    }

    level = std::max(level, lv);
    // std::cout << (int)level << '\n';
}

bool SkipList::erase(int target) {
    Node* curr = head;
    std::vector<Node*> tmp(level);
    for (int i = level - 1; i >= 0; --i) {
        while (curr->forward[i] != nullptr && curr->forward[i]->val < target) {
            curr = curr->forward[i];
        }
        tmp[i] = curr;
    }
    
    curr = curr->forward[0];
    if (curr == nullptr || curr->val != target) return false;
    for (int i = 0; i < level; ++i) {
        if (tmp[i]->forward[i] != curr) break;
        tmp[i]->forward[i] = curr->forward[i];
    }

    delete curr;

    while (level > 0 && head->forward[level - 1] == nullptr) {
        --level;
    }

    std::cout << (int)level << '\n';

    return true;
}

// Random height for a new node: starts at 1 and is promoted one level at a
// time with probability P, capped at MAX_LEVEL.
uint8_t SkipList::randomLevel() {
    uint8_t height = 1;
    for (;;) {
        // Draw before checking the cap so the RNG is consumed exactly as before.
        const bool promote = dis(gen) < P;
        if (!promote || height >= MAX_LEVEL) {
            break;
        }
        ++height;
    }
    return height;
}

#endif // SKIPLIST_H_