普通平衡树(splay)

147 阅读12分钟

前言:

本文章主要介绍普通平衡树(splay) 的性质, 结点维护的信息 ,功能,以及具体实现。

上次,已经学习完BST(二叉搜索树)。其实平衡树,是一种平衡二搜索叉树。也就是说,这棵树具有BST , 所有的性质。但是对BST又有很多优化。主要的优化还是效率问题,对树的结构进行旋转。在此不进行证明,其实结果比较明显,将树的深度压在lognlogn左右!这是控制时间复杂度的重要原因

平衡树的性质:

  • 性质1: 左右子树严格小于或者大于根节点
  • 性质2: 对于每个根节点维护size,记录树的大小
  • 性质3: 树中不允许出现重复的结点

注: 平衡二叉树的性质 , 就是BST具有的性质

平衡树结点的维护的信息

  • 结点本身的信息
  • 根结点 (root)
  • 总的结点数(idx)
  • 插入两个哨兵 , 表示正无穷INF和负无穷-INF(代码中设计的0x3f3f3f3f3f)

image.png

struct node {

    int p , s[2];               // 父节点 , 左右孩子
    int val;                // 值
    int size , cnt;

    void init(int v , int fa){
        size = cnt = 1;
        p = fa , val = v;
    }
}tr[N];
int root , idx;

平衡树的功能

功能1: 旋转子树

  • 左旋 image.png - 右旋
    image.png

为什么左旋和右旋是正确的?

证明右旋:

根据平衡树性质有 ,  x<yx<b<yx<y<c{\ x < y 且 x < b < y 且 x < y < c } , 将该子树旋转之后仍然满足此关系 , 所以该旋转是正确的。

左旋证明类似 , 不予证明。

结论:左旋、右旋不破坏树的有序性

将子树旋转之后 , 因为树的结构发生了改变 , 导致根节点的size发生了改变,所以需要自底向上对受影响的子树结点 进行pushup操作,对size进行重新统计。

image.png

这里的 x 结点, y结点就都需要重新统计,因为size信息是通过子树统计上来的。所以需要先统计y子树 , 再统计x子树

pushup的参考代码:

/**
* @param x  代表当前子树的根节点
* 统计这颗子树的大小 size
*/
void pushup(int x){
  tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + tr[x].cnt;
}

旋转参考代码:

/**
* @param x  传入需要旋转的结点
* 假设 x 是左儿子 , 那么右旋
* 假设 x 是右儿子 , 那么左旋
*
* @variety
* y 为初始旋转的根
* z 为 y为根的这颗子树的根节点
* k 代表子树的方向 0 代表左子树 , 1 代表右子树
* k ^ 1 代表与 x 子树方向 相反方向的子树
* 例如 k == 1 代表右子树 , 那么 k ^ 1 指代的左子树
*/
void rotate(int x) {

  int y = tr[x].p , z = tr[y].p , k = tr[y].s[1] == x;

  // x相反方向的子树b , 成为y的与x方向相同方向子树.      即 b 和 y 建立连接
  tr[tr[x].s[k ^ 1]].p = y;
  tr[y].s[k] = tr[x].s[k ^ 1];

  //  x 与 y 建立连接
  tr[y].p = x;
  tr[x].s[k ^ 1] = y;

  // x 与 z 建立连接
  tr[z].s[tr[z].s[1] == y] = x;
  tr[x].p = z;

  // 旋转之后需要 pushup , 因为子树发生改变 , 需要重新统计size

  pushup(y) , pushup(x);                  // 注意受影响的子树 , 自底向上统计 y ---> x
}

注:这里是将左旋和右旋写在一起了 , 请读者理解性记忆!在图上多画几遍,非常重要的一点就是搞清楚各个参数的意思是什么!比如k代表的就一个方向 , 0 代表左子树 、 1 代表右子树。 注释写的很清楚

功能2: 将x结点翻转成k的子节点

翻转需要判断子树是什么类型 , 不同的类型对应不同翻转方法。

  • y 是根结点(单旋)

image.png

  • y 不是根节点(双旋)
    • 直线形 image.png
    • 折线形

image.png

注: 以上只列举单边的 单旋 、 双左旋 、双右旋 , 另外一边的单旋、双旋, 也是类似的原理 , 请读者自行加深理解。

参考代码:

/**
 * @param x   代表 x 需要旋转的点
 * @param k   代表需要旋转到下面的点
 */

void splay(int x , int k){

    while(tr[x].p != k){        // 旋转到 k 下面就停止
        int y = tr[x].p ,  z = tr[y].p;
        if(z != k)                          // 说明需要双旋     折转底 , 直转中
            (tr[y].s[0] == x) ^ (tr[z].s[0] == y) ? rotate(x) : rotate(y);
        rotate(x);
    }

    if(k == 0) root = x;                // 把x旋转到根 , 那么更新根
}

功能3: 将val这个值移到根结点

实现步骤如下:

  • 从根结点出发
  • 判断val 与 根结点的val值 的关系
    • 如果val 小于根结点的话 , 向左搜索
    • 如果val 大于根结点的话 , 向右搜索
    • 根结点的val 等于 val 的话 , 直接返回

在此讨论一下,搜索可能存在的结果

判断val 与 根结点的val值 的关系
  • val 大于根节点
    • 树上存在val 找到val所在的位置 image.png
    • 树上不存在val
      • 右子树存在 image.png
      • 右子树不存在

image.png

  • val 小于根节点
    • 树上存在val image.png
    • 树上不存在val
      • 左子树存在 image.png
      • 左子树不存在
        image.png

这样就比较清楚了!

结论:
  1. 当val存在时 , 必然能找到
  2. 当val值不存在时, 左/右子树不存在的话。那么就返回根结点 , 如果小于根结点 ,那返回的就是该值的前驱 ,如果大于根结点 ,那么返回的就是该值的后继
  3. 当val值不存在时,左/右子树存在小于根结点左子树走,找到的结点就是该val的后继 , 大于根结点右子树走返回的就是该点的前驱

参考代码:

/**
 * @param val  待查找的值 , 将该旋转到根结点
 */

void find(int val){

    int x = root;
    while(tr[x].s[val > tr[x].val] && val != tr[x].val)
        x = tr[x].s[val > tr[x].val];               // 向下找

    splay(x , 0);               // 1、降低树高  2、好获取答案
}

功能4:寻找前驱结点

步骤:

  1. 将需要寻找前驱的val移到根节点
  2. 判断根结点的val是否小于该val。小于的话,直接作为前驱返回。
  3. 如果不小于的话, 那么从左子树开始,不断地向右子树递归
  4. 直到找到最深的那个结点
  5. 最后spaly 一下该x 即 前驱结点

参考代码

/**
 * @param val   待查找前驱的值
 * @return      返回前驱的下标
 */

int get_pre(int val){

    find(val);

    int x = root;
    if(tr[x].val < val) return x;

    x = tr[x].s[0];             // 从左子树开始
    while(tr[x].s[1])           // 一直向右走
        x = tr[x].s[1];

    splay(x , 0);               // 调整一下树高

    return x;
}

功能5:寻找后继结点

步骤:

  1. 需要寻找后继的val移到根结点 , 使用find函数即可
  2. 判断根结点的val是否大于该val。如果大于的话 , 直接返回该后继结点即可。
  3. 如果不大于的val,从右子树开始,不断地向左边搜索
  4. 直到最深的结点
  5. 最后spaly 一下该x 即 后继结点

参考代码:

/**
 * @param val 待查找后继的点
 * @return  返回后继的下标
 */

int get_suc(int val){

    find(val);

    int x = root;
    if(tr[x].val > val) return x;

    x = tr[x].s[1];             // 从右子树开始
    while(tr[x].s[0])           // 一直向左边走
        x = tr[x].s[0];

    splay(x , 0);

    return x;
}

功能6: 删除结点

步骤:

  1. 将该val的结点的前驱结点旋转到根结点
  2. 将该val的后继结点后继结点旋转到根节点的成为右儿子
  3. 那么我们可以得到val一定是后继结点的左儿子而且这个左儿子一定是叶子结点。
  4. 拿到val这个结点判断cnt
    • 如果大于1的话 , 直接减1即可 , 再splay一下这个需删除的结点
    • 如果等于1的话 ,后继结点直接丢掉左儿子即可 ,再splay一下后继结点
  • 为什么需要splay 一下呢 ?

    • 因为删除这个结点,需要重新统计一下受影响的树的size , 其次就是调整树的深度
  • 为什么我们不直接删除这个结点,而是通过将val这个结点旋转到叶子结点呢?

    • 答案是:如果直接删除的话,至少需要操作6根指针。更新size更是非常的麻烦!所以将val这个结点旋转到叶子结点是为了更好处理。
  • 为什么不特判结点不存在?

    • 因为不存在的话 , 无非对空结点进行操作,将后继结点splay ,可以直接cnt = 1的处理方法,所以不需要特判。

参考代码

/**
 * @param val 待删除的结点
 */
void del(int val){

    int pre = get_pre(val);
    int suc = get_suc(val);

    splay(pre , 0 ) , splay(suc , pre);

    int del = tr[suc].s[0];             // val 结点

    // splay 因为删除一个结点 , 所以需要splay 重新统计size 信息
    if(tr[del].cnt > 1)
        tr[del].cnt -- , splay(del , 0 );
    else
        tr[suc].s[0] = 0 , splay(suc , 0);
}

功能7: 通过排名查val

步骤:

  1. 插入结点val(插入结点后面会详细讲 , 会将这个val移到根)
  2. 记录根结点的左子树的size
  3. 删除val这个结点
  4. 返回答案即可

这里作者发现了一个错误的思路,导致被困4小时,疯狂被hack ,题目点此

这组数据

5
1 20
1 24
1 25
3 23
3 23

如果我们使用

find(val);
return tr[tr[root].s[0]].size;

得到的答案是:2 1

实际答案是 : 2 2

image.png

这是我找到的被hack的原因,都在图上。

功能8: 通过val查排名

步骤:

  1. 通过性质判断
  2. 如果小于 , 则向左搜索
  3. 如果大于 , 则向右搜索 ,因为统计的size是子树的所以 k -=tr[tr[x].s[0]].size +tr[x].cnt
  4. 如果等于或者找到空结点,则停止搜索
  5. splay 一下答案结点x,返回答案x

参考代码:

/**
 * @param k   查找排名为k的下标
 * @return   返回查找到的值
 */
int get_val(int k){                 // 根据排名找值

    int x = root;

    while(1){
        int y = tr[x].s[0];
        if(tr[y].size + tr[x].cnt < k){     // 说明在右子树
            k -=  tr[y].size + tr[x].cnt;
            x = tr[x].s[1];
        }else
            if(tr[y].size >= k) x = tr[x].s[0];
            else
                break;
    }

    splay(x , 0);

    return tr[x].val;
}

功能9:插入结点

步骤:

  1. 从根节点出发
  2. 判断与根结点val的关系,向左子树或者右子树搜索,同时记录该结点的父节点。
  3. 如果搜索到该结点, 就停下来。
  4. 判断一下该结点是不是空结点。如果不是空,那么找到树中存在的结点。那么直接tr[x].cnt 加一即可 如果是空结点的话,那么创建该结点。
  5. 最后splay一下,统计一下树的size

参考代码:

/**
 * @param val 需要插入的值
 */

void insert(int val){

    int x  = root , p = 0;

    while(x && tr[x].val != val)
        p = x , x = tr[x].s[val > tr[x].val];

    if(x)
        tr[x].cnt ++;
    else{
        x = ++ idx;              // 建结点
        tr[p].s[val > tr[p].val] = x;                   // 父节点指向子节点
        tr[x].init(val , p);
    }

    splay(x , 0);                   // 重新统计树的 size
}

可以发现,这棵树代码非常的长!但是功能超全! 这种数据结构成功的将BST 取代了, 如此优雅的一棵树啊!

image.png

总的参考代码:

#include <iostream>

using namespace std;

const int N = 100010 , INF = 0x3f3f3f3f;

struct node {

    int p , s[2];               // 父节点 , 左右孩子
    int val;                // 值
    int size , cnt;

    void init(int v , int fa){
        size = cnt = 1;
        p = fa , val = v;
    }
}tr[N];

int root , idx;
int n;

/**
 * @param x  代表当前子树的根节点
 * 统计这颗子树的大小 size
 */
void pushup(int x){
    tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + tr[x].cnt;
}
/**
 * @param x  传入需要旋转的结点
 * 假设 x 是左儿子 , 那么右旋
 * 假设 x 是右儿子 , 那么左旋
 *
 * @variety
 * \t y 为初始旋转的根\n
 * \t z 为 y为根的这颗子树的根节点\n
 * \t k 代表子树的方向 0 代表左子树 , 1 代表右子树\n
 * \t k ^ 1 代表与 x 子树方向 相反方向的子树 \n  例如 k == 1 代表右子树 , 那么 k ^ 1 指代的左子树
 */
void rotate(int x) {

    int y = tr[x].p , z = tr[y].p , k = tr[y].s[1] == x;

    // x相反方向的子树b , 成为y的与x方向相同方向子树.      即 b 和 y 建立连接
    tr[tr[x].s[k ^ 1]].p = y;
    tr[y].s[k] = tr[x].s[k ^ 1];

    //  x 与 y 建立连接
    tr[y].p = x;
    tr[x].s[k ^ 1] = y;

    // x 与 z 建立连接
    tr[z].s[tr[z].s[1] == y] = x;
    tr[x].p = z;

    // 旋转之后需要 pushup , 因为子树发生改变 , 需要重新统计size

    pushup(y) , pushup(x);                  // 注意受影响的子树 , 自底向上统计 y ---> x
}

/**
 *
 * @param x   代表 x 需要旋转的点
 * @param k   代表需要旋转到下面的点
 */

void splay(int x , int k){

    while(tr[x].p != k){        // 旋转到 k 下面就停止
        int y = tr[x].p ,  z = tr[y].p;
        if(z != k)                          // 说明需要双旋     折转底 , 直转中
            (tr[y].s[0] == x) ^ (tr[z].s[0] == y) ? rotate(x) : rotate(y);
        rotate(x);
    }

    if(k == 0) root = x;                // 把x旋转到根 , 那么更新根
}

/**
 * @param val 需要插入的值
 */

void insert(int val){

    int x  = root , p = 0;

    while(x && tr[x].val != val)
        p = x , x = tr[x].s[val > tr[x].val];

    if(x)
        tr[x].cnt ++;
    else{
        x = ++ idx;              // 建结点
        tr[p].s[val > tr[p].val] = x;                   // 父节点指向子节点
        tr[x].init(val , p);
    }

    splay(x , 0);                   // 重新统计树的 size
}

/**
 * @param val  待查找的值 , 将该旋转到根结点
 */

void find(int val){

    int x = root;
    while(tr[x].s[val > tr[x].val] && val != tr[x].val)
        x = tr[x].s[val > tr[x].val];               // 向下找

    splay(x , 0);               // 1、降低树高  2、好获取答案
}

/**
 * @param val   待查找前驱的值
 * @return      返回前驱的下标
 */

int get_pre(int val){

    find(val);

    int x = root;
    if(tr[x].val < val) return x;

    x = tr[x].s[0];             // 从左子树开始
    while(tr[x].s[1])           // 一直向右走
        x = tr[x].s[1];

    splay(x , 0);               // 调整一下树高

    return x;
}

/**
 * @param val 待查找后继的点
 * @return  返回后继的下标
 */

int get_suc(int val){

    find(val);

    int x = root;
    if(tr[x].val > val) return x;

    x = tr[x].s[1];             // 从右子树开始
    while(tr[x].s[0])           // 一直向左边走
        x = tr[x].s[0];

    splay(x , 0);

    return x;
}


/**
 * @param val 待删除的结点
 */
void del(int val){

    int pre = get_pre(val);
    int suc = get_suc(val);

    splay(pre , 0 ) , splay(suc , pre);

    int del = tr[suc].s[0];             // val 结点

    // splay 因为删除一个结点 , 所以需要splay 重新统计size 信息
    if(tr[del].cnt > 1)
        tr[del].cnt -- , splay(del , 0 );
    else
        tr[suc].s[0] = 0 , splay(suc , 0);
}


/**
 * @param val  待查找排名的值
 * @return  返回该值的排名
 */
int get_rank(int val){         // 获取排名

    insert(val);

    int res = tr[tr[root].s[0]].size;

    del(val);

    return res;                 // 因为还有一个哨兵 , 所以不需要加一
}

/**
 * @param k   查找排名为k的下标
 * @return   返回查找到的值
 */
int get_val(int k){                 // 根据排名找值

    int x = root;

    while(1){
        int y = tr[x].s[0];
        if(tr[y].size + tr[x].cnt < k){     // 说明在右子树
            k -=  tr[y].size + tr[x].cnt;
            x = tr[x].s[1];
        }else
            if(tr[y].size >= k) x = tr[x].s[0];
            else
                break;
    }

    splay(x , 0);

    return tr[x].val;
}



int main(){

    insert(-INF) , insert(INF);

    scanf("%d" , &n);

    int op , x;
    while(n -- ){
        scanf("%d%d" , &op , &x);

        if(op == 1) insert(x);
        else if(op == 2) del(x);
        else if(op == 3) printf("%d\n" , get_rank(x));
        else if(op == 4) printf("%d\n" , get_val(x + 1));
        else if(op == 5) printf("%d\n" , tr[get_pre(x)].val);
        else printf("%d\n" , tr[get_suc(x)].val);
    }

    return 0;
}