普通BST模板

86 阅读7分钟

普通BST(Binary Search Tree)

前言: 本文章主要介绍的BST 的功能、性质、如何实现、以及优点和缺点。 二叉搜索树是一种树形的数据结构,后面有很多根据此树的数据结构。 例如 "平衡二叉树" , "文艺平衡树" , "splay" 等等,BST 对于后面的学习打下重要基础。

二叉搜索树,又称BST。此树没有所谓升序构造和降序构造。 因为不过就是左右子树的性质互换,没有本质区别。

1. 该数据结构的功能

主要可以完成的功能包括:

  • 查询序列中排名为x的结点
  • 查询序列中x数的排名
  • 查询x数的前序
  • 查询x数的后继
  • 增加结点
  • 删除结点

注:快选算法虽然可以做到O(n)的时间复杂度完成功能1,但是在数据分布随机情况下BST可以做到O(logn) 但是就极端情况BST有其致命缺点!

2. BST 维护的性质

BST树维护了以下的性质:

  • 对于左子树val值一定严格小于根结点的val
  • 对于右子树val值一定严格大于根结点的val
  • 对于每个结点维护cnt , 为了维护性质1、2树上不允许存在重复val值的结点,所以维护cnt记录当前结点的数量
  • 对于每个结点维护size值,size定义:左右子树大小及其根结点的数量(即cnt数)
  • 对于每个结点维护左右子树指针(一般存储下标)

3. 数据存储形式

主要形式:结构体

以下是存储属性:

image.png

4. 功能的实现

1、增添结点

参数: x 当前所在结点的指针 val 序列中需要插入的元素

  1. 从根节点出发 , 递归到 x 结点维护size + 1 (因为到当前结点,插入结点肯定会这棵树上).
  2. 判断当前的根结点是不是需要添加的元素 , 若匹配 , 当前结点的元素个数 cnt + 1
  3. 如果不是的话(即未匹配上) , 那么判断该元素的val 与 需要插入的val的关系。
    • 如果 val < tr[x].val , 那么递归到左子树
      • 如果左子树不存在 , 那么直接创建一个结点 , 成为根结点的左儿子
      • 如果左子树存在 , 那么向左子树递归搜索
    • 如果 val > tr[x].val , 那么递归到右子树
      • 同理 , 如果右子树不存在 , 那么直接创建一个结点 , 成为根结点的右儿子
      • 如果右子树存在 , 那么向右子树递归搜索

image.png

void add(int x , int val){

    tr[x].size ++;

    if(tr[x].val == val) {
        tr[x].cnt ++;
        return ;
    }

    if(val < tr[x].val)
        if(tr[x].ls)
            add(tr[x].ls , val);
        else{
            ++cnt;
            tr[cnt].val = val;
            tr[cnt].size = tr[cnt].cnt = 1;
            tr[x].ls = cnt;
        }
    else
        if(tr[x].rs)
            add(tr[x].rs , val);
        else{
            ++cnt;
            tr[cnt].val = val;
            tr[cnt].size = tr[cnt].cnt = 1;
            tr[x].rs = cnt;
        }
}

2、查询前驱

参数: x 当前所在结点的指针 , val 需要查询的数 , ans* 当前找到的前驱

  1. 从根节点开始 , 需要特判一下前驱结点 , 为空树的话 直接返回ans.
  2. 如果不是空树的话 , 分情况讨论
    • val <= t[x].val ------------------(<= 有讲究)
      • 如果左子树不存在的话 , 那只能返回找到已经得到的 ans
      • 如果左子树存在的话 , 那就向左子树继续寻找答案 , 但是根结点无法更新ans
    • val > t[x].val
      • 如果右子树不存在的话 , 还需要判断一下根结点是否被删除(即cnt != 0)
        • 如果存在的话 , 返回t[x].val 作为答案
        • 如果不存在的话 ,返回ans 作为答案.
      • 如果右子树存在的话 , 还需要判断根是否被删除
        • 如果根结点被删除掉的话 , 向右子树递归但是无法更新 ans
        • 如果根结点没有被删除的话 , 向右子树递归 , 将 ans 更新成 tr[x].val
int query_fr(int x , int val , int ans){

    if(cnt == 0) return  ans;               // 保证根结点一定存在 , 不存在的话 , 直接返回 INF

    if(val <= tr[x].val)
        if(tr[x].ls == 0) return ans;
        else
            return query_fr(tr[x].ls , val , ans);
    else{
        if(tr[x].rs == 0)
            if(tr[x].cnt !=0)                       // 判断一下根结点是否被删除掉
                return tr[x].val;
            else
                return ans;

        if(tr[x].cnt != 0) return query_fr(tr[x].rs , val , tr[x].val);
        else
            return query_fr(tr[x].rs , val , ans);
    }

}

3、查询后继

参数: x 当前所在结点的指针 , val 需要查询的数 , ans 当前找到的后继

  1. 从根节点开始 , 并不需要特判后继 , 全局变量自动初始化后继(即 0)
  2. 利用该树的性质 , 同样地对 val 和 tr[x].val 分情况讨论
    • val >= tr[x].val ------------------(>= 有讲究)
      • 如果右子树不存在的话 , 那只能返回 ans 作为答案
      • 如果右子树存在的话 , 那么就向右子树递归 , 但是无法更新 ans
    • val < tr[x].val
      • 如果左子树不存在的话 , 判断一下根节点是否被删除
        • 如果根结点未被删除 , 那么返回tr[x].val
        • 如果根结点被删除 , 那么返回 ans
      • 如果左子树存在的话 , 判断一下根节点是否被删除
        • 如果根结点未被删除 , 那么更新 ans 为 tr[x].val 向左子树递归
        • 如果根结点被删除 , 那么不能更新 ans , 向左子树递归

image.png

int query_be(int x , int val , int ans){

    if(val >= tr[x].val)
        if(tr[x].rs == 0) return ans;
        else
            return query_be(tr[x].rs , val , ans);
    else{

        if(tr[x].ls == 0)
            if(tr[x].cnt != 0)
                return tr[x].val;
            else
                return ans;

        if(tr[x].cnt != 0) return query_be(tr[x].ls , val , tr[x].val);
        else
            return query_be(tr[x].ls , val , ans);
    }
}

4、查询x数的排名

参数: x 当前所在的结点 , val 查询的x数

  1. 从根结点开始
    • 先判断这个结点是否存在 , 不存在的话 , 直接返回 0.
    • 对于特殊情况 , 空树的话 , 最后还是会返回0 , 因为无论是向左还是向右都是空树。
  2. val 是否与 tr[x].val(根结点)匹配 , 如果匹配的话 , 直接返回左边结点数(即 tr[x].val );
  3. val < tr[x].val , 那么向左子树递归
  4. val > tr[x].val , 那么向右子树递归 , 需要注意一下 , 右边结点的排名计算
size=tr[tr[x].ls].size+tr[x].cnt+returnquery(tr[x].rs,val); size = tr[ tr[x].ls ].size + tr[x].cnt + _{return} query(tr[x].rs , val);

注意:结果一定要加 1 , 函数返回的结果只是大于多少数, 需要 +1+ 1 才是该数的真实排名

image.png

int query_val(int x , int rank){

    if(x == 0) return INF;

    if(tr[tr[x].ls].size >= rank)
        return query_val(tr[x].ls , rank);
    else if(tr[tr[x].ls].size + tr[x].cnt >= rank)
        return tr[x].val;
    else
        return query_val(tr[x].rs , rank - tr[tr[x].ls].size - tr[x].cnt);
}

5、查询排名为x的数

参数: x当前所在的结点 , rank 查询的排名x

  1. 从根节点开始
    • 判断这个结点是否存在如果不存在的话 , 直接返回 INF
    • 对应特殊情况 , 即空树的情况 , 因为左右子树都是空 , 所以无论走哪边都是返回INF
  2. 如果左子树的size 是否大于等于 rank , 那么说明在左子树 , 向左边递归
  3. 如果左子树的size + 根结点的cnt , 那么说明在根结点 , 那么直接返回根节点的值 tr[x].val
  4. 以上情况都不符合的话 , 那么就是在右子树。 但是要注意 , 参数rank 在向右子树递归的要做处理!
rank=ranktr[tr[x].ls].sizetr[x].cnt {rank = rank - tr[tr[x].ls].size - tr[x].cnt}

因为对于每棵树(整棵树和子树) , 都是维护的局部的size , 所以一定要做rank的变换

int query_rank(int x , int val){

    if(x == 0) return 0;

    if(val == tr[x].val) return tr[tr[x].ls].size;
    else if(val < tr[x].val)
        return query_rank(tr[x].ls , val);
    else
        return query_rank(tr[x].rs , val) + tr[tr[x].ls].size + tr[x].cnt;
}

6、删除某个结点

参数:x 当前所在的结点 , val

  1. 第一步特判结点是否存在 , 不存在直接返回
  2. 如果这个结点的val 和 需要寻找的val 都匹配的话 , 那么需要判断一下该结点是否已经删除完 , 若该结点没有被删除完的话 , tr[x].cnt -- 即删除一个结点。 直接返回即可。
  3. 如果val < tr[x].val , 那么需要向左子树递归
  4. 如果val > tr[x].val , 那么需要向右子树递归
void delet(int x , int val){

    if(x == 0) return ;

    if(val == tr[x].val){
        if(tr[x].cnt != 0)
            tr[x].cnt --;
        return ;
    }
    else if(val < tr[x].val)
        delet(tr[x].ls , val);
    else
        delet(tr[x].rs , val);
}

该数据结构虽然有很多的好的功能,但是对于一些极端情况该树的时间复杂度依旧会非常高!达到了O(n2)O(n^2) , 所以想需要性能更优的数据结构还得需要加深学习

参考代码:

#include <iostream>

using namespace std;

const int N = 10010 , INF = 0x7fffffff;

int n , cnt;
struct node{
    int val;
    int ls , rs;
    int size , cnt;
}tr[N];


void add(int x , int val){

    tr[x].size ++;

    if(tr[x].val == val) {
        tr[x].cnt ++;
        return ;
    }

    if(val < tr[x].val)
        if(tr[x].ls)
            add(tr[x].ls , val);
        else{
            ++cnt;
            tr[cnt].val = val;
            tr[cnt].size = tr[cnt].cnt = 1;
            tr[x].ls = cnt;
        }
    else
        if(tr[x].rs)
            add(tr[x].rs , val);
        else{
            ++cnt;
            tr[cnt].val = val;
            tr[cnt].size = tr[cnt].cnt = 1;
            tr[x].rs = cnt;
        }
}

int query_fr(int x , int val , int ans){

    if(cnt == 0) return  ans;               // 保证根结点一定存在 , 不存在的话 , 直接返回 INF

    if(val <= tr[x].val)
        if(tr[x].ls == 0) return ans;
        else
            return query_fr(tr[x].ls , val , ans);
    else{
        if(tr[x].rs == 0)
            if(tr[x].cnt !=0)                       // 判断一下根结点是否被删除掉
                return tr[x].val;
            else
                return ans;

        if(tr[x].cnt != 0) return query_fr(tr[x].rs , val , tr[x].val);
        else
            return query_fr(tr[x].rs , val , ans);
    }

}

int query_be(int x , int val , int ans){

    if(val >= tr[x].val)
        if(tr[x].rs == 0) return ans;
        else
            return query_be(tr[x].rs , val , ans);
    else{

        if(tr[x].ls == 0)
            if(tr[x].cnt != 0)
                return tr[x].val;
            else
                return ans;

        if(tr[x].cnt != 0) return query_be(tr[x].ls , val , tr[x].val);
        else
            return query_be(tr[x].ls , val , ans);
    }
}

int query_val(int x , int rank){

    if(x == 0) return INF;

    if(tr[tr[x].ls].size >= rank)
        return query_val(tr[x].ls , rank);
    else if(tr[tr[x].ls].size + tr[x].cnt >= rank)
        return tr[x].val;
    else
        return query_val(tr[x].rs , rank - tr[tr[x].ls].size - tr[x].cnt);
}

int query_rank(int x , int val){

    if(x == 0) return 0;

    if(val == tr[x].val) return tr[tr[x].ls].size;
    else if(val < tr[x].val)
        return query_rank(tr[x].ls , val);
    else
        return query_rank(tr[x].rs , val) + tr[tr[x].ls].size + tr[x].cnt;
}

void delet(int x , int val){

    if(x == 0) return ;

    if(val == tr[x].val){
        if(tr[x].cnt != 0)
            tr[x].cnt --;
        return ;
    }
    else if(val < tr[x].val)
        delet(tr[x].ls , val);
    else
        delet(tr[x].rs , val);
}


int main(){

    scanf("%d" , &n);

    int x  ,val;
    while (n -- ){

        scanf("%d%d" , &x , &val);
        if(x == 1) printf("%d\n" , query_rank(1  , val) + 1);
        if(x == 2) printf("%d\n" , query_val(1 , val));
        if(x == 3) printf("%d\n" , query_fr(1, val , -INF));
        if(x == 4) printf("%d\n" , query_be(1 , val , INF));
        if(x == 5)
            if(cnt == 0){
                ++ cnt;
                tr[cnt].cnt= tr[cnt].size = 1;
                tr[cnt].val = val;
            }else
                add(1 , val);

    }

    return 0;
}