【笔记】Treap

291 阅读6分钟

这是我参与11月更文挑战的第24天,活动详情查看:2021最后一次更文挑战.


随机平衡二叉搜索树,又名Treap,树堆. 普通二叉搜索树在读入某些输入时可能会退化成链状,从logn退化成n. 但是注意到如果输入前将所有数据随机重排一遍,BST就会有很好的高度性质. 显然不可能全部读入数据后再随机重排,所以机智的人们发明了treap: 对每个数据节点增加一个值叫做优先级,初始化为一个随机数. 然后要求BST的所有节点的优先级总要比两个子节点的优先级高. 这像什么?是的,一个堆. treap节点的数据值满足二叉检索树的性质,优先级值满足堆的性质,所以称为树堆. (记得做过一个题要求合并后缀前缀相同的两个字符串,tree和heap就能合成为treap)

BST的性质用BST的操作就可以满足,注意对于一组数据,可以构造的bst有相当多种. 怎么找到满足堆性质的那棵树呢,使用旋转操作. 比如下图中,x节点与父节点node的优先级关系不满足,这个时候使用一个右旋操作就可以啦. 百度百科盗来的图

二叉树都是高度递归的结构,treap也不例外. 插入节点:每次递归插入操作后,跟一个检查旋转就可以实现insert. 删除节点:treap有独特的删除方式,只要将待删除的节点不断旋转到叶子处,直接删除就可以了,或者碰到只有一个孩子的那种节点,然后移花接木一波. 查找以及其他的操作都与二叉树类似. 每次旋转操作只需要改变两个子节点的指向(还有节点的一些其他信息),复杂度是O(1). 所以增删改查的复杂度都是O(logn)的.

下面是紧张刺激的实现部分: 代码参考自百度百科,后来发现是__hzwer__大佬的作品,难怪和其他不能看的百度百科不一样.

全局变量:根节点索引root,新节点插入位置id.

根节点好说,只要知道所有的函数都是传引用就行(这条很重要,详见后文) 这个treap是数组实现的,每次插入值都会插入在id处,然后id++,而删除值不会使得id--. 之前的空间实际上没有作用了(当然你也可以手写一个链表内存池)

数组数据:左,右子树索引,节点值,节点优先级,重复数字数,下辖数字数.

前四个都好理解. 重复数字数指的是将所有值相同的数字放在一个节点里,额外插入删除只不过是值++--的问题,超级方便.当初写二叉树怎么没想到这个. 下辖数字个数记录了这个节点及它的孩子们一共包括多少数字,这个数据是实现随机访问的基础.

private函数:void update(int rt),void lturn(int &rt),void rturn(int &rt)

lturn和rturn分别是左旋和右旋,update用来旋转后更新节点的下辖数字数.

public函数:

void insert(int &rt, int num)

插入,不管是初始化还是递归都要记得更改下辖数字数. 没有显式指定父节点的操作,因为递归传引用自动更改了! insert(save[rt].right, num); 真是神奇!

void del(int &rt, int num)

删除, 上文说的差不多了. 学习了一个神奇操作,表示一个节点孩子不满可以判断left*right==0,然后直接让root=left+right. 前提是0表示空树. 注意旋转后传递删除时lturn(rt), del(rt, num);节点传rt!!!因为传引用已经更改rt了!!!!! 引用会改参数!!!! 调了两个小时的bug!!!!

int getrank(int rt, int num)

获得数据排名,相当于lower_bound.分别在左右子树和root的重复数字数之间判断即可.

int getnum(int rt, int rank)

随机访问.分别在左右子树和root的重复数字数之间判断即可.

int getprev(int rt, int num)

求前驱,注意可能是父节点,也可能是子节点,需要用max递归

int getnext(int rt, int num)

同上,需要用到min递归.

/* LittleFall : Hello! */
#include <bits/stdc++.h>
#define ll long long
using namespace std;
inline int read();
inline void write(int x);
const int M = 1000016;

//Treap
int root; //当前根节点
int id; //新节点插入位置
struct Treap
{
    int left;  	//左子树
    int right; 	//右子树
    int value; 	//节点值
    int prior; 	//随机权重,注意堆是最大堆
    int manage;	//下辖数字个数
    int repeat; //本值重复出现的次数
} save[M]; 		//从1计数
//private:
void update(int rt)
{
    save[rt].manage = save[save[rt].left].manage
                      + save[save[rt].right].manage
                      + save[rt].repeat;
}
void lturn(int &rt)
{ 
    //左旋rt,并将nrt置为rt
    int nrt = save[rt].right;
    save[rt].right = save[nrt].left;
    save[nrt].left = rt;
    save[nrt].manage = save[rt].manage;
    update(rt);
    rt = nrt;
}
void rturn(int &rt)
{
    int nrt = save[rt].left;
    save[rt].left = save[nrt].right;
    save[nrt].right = rt;
    save[nrt].manage = save[rt].manage;
    update(rt);
    rt = nrt;
}

//public:
void insert(int &rt, int num)
{
    //rt为当前根节点,类型是引用:返回时可以置子节点.
    if(rt == 0) //插入新叶子
    {
        id++;
        rt = id;
        save[rt].value = num;
        save[rt].manage = 1;
        save[rt].repeat = 1;
        save[rt].prior = rand();
        return;
    }
    save[rt].manage++;
    if(num == save[rt].value) //插入旧叶子
    {
        save[rt].repeat++;
    }
    else if(num > save[rt].value) //插入右子树
    {
        insert(save[rt].right, num);
        if(save[save[rt].right].prior > save[rt].prior)
            lturn(rt);
    }
    else //插入左子树
    {
        insert(save[rt].left, num);
        if(save[save[rt].left].prior > save[rt].prior)
            rturn(rt);
    }
}
//首先像常规BST一样查找删除位置,不考虑查找不到的情况
//然后可以直接将节点旋转到叶子然后直接剥离
//左旋还是右旋取决于左右prior的大小
//需要考虑repeat>1的情况,以及某个节点仅有一个孩子的情况.
void del(int &rt, int num)
{
    if(rt == 0) return;
    if(save[rt].value == num) //find it
    {
        if(save[rt].repeat > 1) //重复节点
            save[rt].repeat--,save[rt].manage--;
        else if(save[rt].left * save[rt].right == 0) //只有一个或没有孩子
            rt = save[rt].left + save[rt].right;
        else if(save[save[rt].left].prior < save[save[rt].right].prior)
            lturn(rt), del(rt, num);
        else
            rturn(rt), del(rt, num);
    }
    else if(num > save[rt].value) //需要在右子树中查找
        save[rt].manage--,del(save[rt].right, num);
    else
        save[rt].manage--,del(save[rt].left, num);
}
//注意排名定义为比当前数小的数字个数+1
int getrank(int rt, int num)
{
    if(rt == 0) return 0; //查找失败,返回一个lower_bound??
    if(num == save[rt].value) //find it
        return save[save[rt].left].manage+1;
    else if(num > save[rt].value) //需要在右子树中查找
        return save[save[rt].left].manage +
               save[rt].repeat + 
               getrank(save[rt].right,num);
	return getrank(save[rt].left,num);
}
int getnum(int rt, int rank)
{
	if(rt==0) return 0; //??
	if(rank<=save[save[rt].left].manage) //落在左子树中
		return getnum(save[rt].left,rank);
	rank-=save[save[rt].left].manage;
	if(rank<=save[rt].repeat) //就是本点
		return save[rt].value;
	//落在右子树中
	return getnum(save[rt].right,rank-save[rt].repeat);
}
int getprev(int rt, int num)
{
	if(rt==0) return INT_MIN; //本区域无解
	if(save[rt].value>=num)
		return getprev(save[rt].left,num);
	return max(save[rt].value,getprev(save[rt].right,num));
}
int getnext(int rt, int num)
{
	if(rt==0) return INT_MAX; 
	if(save[rt].value<=num)
		return getnext(save[rt].right,num);
	return min(save[rt].value,getnext(save[rt].left,num));
}
//unrealized
//int count(int rt,int num); //记数,返回repeat或者0


int main(void)
{
#ifdef _LITTLEFALL_
    freopen("in.txt", "r", stdin);
#endif
    //std::cin.sync_with_stdio(false);

    int n = read();
    for(int i = 0; i < n; ++i)
    {
        int op = read(), x = read();
        switch(op)
        {
        case 1:
            insert(root, x);
            break;
        case 2:
            del(root, x);
            break;
        case 3:
            printf("%d\n", getrank(root, x));
            break;
        case 4:
            printf("%d\n", getnum(root, x));
            break;
        case 5:
            printf("%d\n", getprev(root, x));
            break;
        case 6:
            printf("%d\n", getnext(root, x));
            break;
        }
    }

    return 0;
}


inline int read()
{
    int x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
        if(ch == '-')f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
        x = x * 10 + ch - '0';
        ch = getchar();
    }
    return x * f;
}
inline void write(int x)
{
    if(x < 0) putchar('-'), x = -x;
    if(x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

特别困了emmm先就这样. 为了方便理解所以都是全称,但是考虑到以后要打板子,是不是换成简称好些?

treap加入模板库.已简化


本文也发表于我的 CSDN 博客中。