前言:
本文章主要介绍普通平衡树(splay) 的性质, 结点维护的信息 ,功能,以及具体实现。
上次,已经学习完BST(二叉搜索树)。其实平衡树,是一种平衡二搜索叉树。也就是说,这棵树具有BST , 所有的性质。但是对BST又有很多优化。主要的优化还是效率问题,对树的结构进行旋转。在此不进行证明,其实结果比较明显,将树的深度压在左右!这是控制时间复杂度的重要原因!
平衡树的性质:
- 性质1: 左右子树严格小于或者大于根节点
- 性质2: 对于每个根节点维护size,记录树的大小
- 性质3: 树中不允许出现重复的结点
注: 平衡二叉树的性质 , 就是BST具有的性质
平衡树结点的维护的信息
- 结点本身的信息
- 根结点 (root)
- 总的结点数(idx)
- 插入两个哨兵 , 表示正无穷INF和负无穷-INF(代码中设计的0x3f3f3f3f3f)
struct node {
int p , s[2]; // 父节点 , 左右孩子
int val; // 值
int size , cnt;
void init(int v , int fa){
size = cnt = 1;
p = fa , val = v;
}
}tr[N];
int root , idx;
平衡树的功能
功能1: 旋转子树
- 左旋
- 右旋
为什么左旋和右旋是正确的?
证明右旋:
根据平衡树性质有 , , 将该子树旋转之后仍然满足此关系 , 所以该旋转是正确的。
左旋证明类似 , 不予证明。
结论:左旋、右旋不破坏树的有序性。
将子树旋转之后 , 因为树的结构发生了改变 , 导致根节点的size发生了改变,所以需要自底向上的对受影响的子树结点 进行pushup操作,对size进行重新统计。
这里的 x 结点, y结点就都需要重新统计,因为size信息是通过子树统计上来的。所以需要先统计y子树 , 再统计x子树。
pushup的参考代码:
/**
* @param x 代表当前子树的根节点
* 统计这颗子树的大小 size
*/
void pushup(int x){
tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + tr[x].cnt;
}
旋转参考代码:
/**
* @param x 传入需要旋转的结点
* 假设 x 是左儿子 , 那么右旋
* 假设 x 是右儿子 , 那么左旋
*
* @variety
* y 为初始旋转的根
* z 为 y为根的这颗子树的根节点
* k 代表子树的方向 0 代表左子树 , 1 代表右子树
* k ^ 1 代表与 x 子树方向 相反方向的子树
* 例如 k == 1 代表右子树 , 那么 k ^ 1 指代的左子树
*/
void rotate(int x) {
int y = tr[x].p , z = tr[y].p , k = tr[y].s[1] == x;
// x相反方向的子树b , 成为y的与x方向相同方向子树. 即 b 和 y 建立连接
tr[tr[x].s[k ^ 1]].p = y;
tr[y].s[k] = tr[x].s[k ^ 1];
// x 与 y 建立连接
tr[y].p = x;
tr[x].s[k ^ 1] = y;
// x 与 z 建立连接
tr[z].s[tr[z].s[1] == y] = x;
tr[x].p = z;
// 旋转之后需要 pushup , 因为子树发生改变 , 需要重新统计size
pushup(y) , pushup(x); // 注意受影响的子树 , 自底向上统计 y ---> x
}
注:这里是将左旋和右旋写在一起了 , 请读者理解性记忆!在图上多画几遍,非常重要的一点就是搞清楚各个参数的意思是什么!比如k代表的就一个方向 , 0 代表左子树 、 1 代表右子树。 注释写的很清楚
功能2: 将x结点翻转成k的子节点
翻转需要判断子树是什么类型 , 不同的类型对应不同翻转方法。
- y 是根结点(单旋)
- y 不是根节点(双旋)
- 直线形
- 折线形
- 直线形
注: 以上只列举单边的 单旋 、 双左旋 、双右旋 , 另外一边的单旋、双旋, 也是类似的原理 , 请读者自行加深理解。
参考代码:
/**
* @param x 代表 x 需要旋转的点
* @param k 代表需要旋转到下面的点
*/
void splay(int x , int k){
while(tr[x].p != k){ // 旋转到 k 下面就停止
int y = tr[x].p , z = tr[y].p;
if(z != k) // 说明需要双旋 折转底 , 直转中
(tr[y].s[0] == x) ^ (tr[z].s[0] == y) ? rotate(x) : rotate(y);
rotate(x);
}
if(k == 0) root = x; // 把x旋转到根 , 那么更新根
}
功能3: 将val这个值移到根结点
实现步骤如下:
- 从根结点出发
- 判断val 与 根结点的val值 的关系
- 如果val 小于根结点的话 , 向左搜索
- 如果val 大于根结点的话 , 向右搜索
- 根结点的val 等于 val 的话 , 直接返回
在此讨论一下,搜索可能存在的结果
判断val 与 根结点的val值 的关系
- val 大于根节点
- 树上存在val 找到val所在的位置
- 树上不存在val
- 右子树存在
- 右子树不存在
- 右子树存在
- 树上存在val 找到val所在的位置
- val 小于根节点
- 树上存在val
- 树上不存在val
- 左子树存在
- 左子树不存在
- 左子树存在
- 树上存在val
这样就比较清楚了!
结论:
- 当val存在时 , 必然能找到
- 当val值不存在时, 左/右子树不存在的话。那么就返回根结点 , 如果小于根结点 ,那返回的就是该值的前驱 ,如果大于根结点 ,那么返回的就是该值的后继。
- 当val值不存在时,左/右子树存在。小于根结点向左子树走,找到的结点就是该val的后继 , 大于根结点向右子树走返回的就是该点的前驱
参考代码:
/**
* @param val 待查找的值 , 将该旋转到根结点
*/
void find(int val){
int x = root;
while(tr[x].s[val > tr[x].val] && val != tr[x].val)
x = tr[x].s[val > tr[x].val]; // 向下找
splay(x , 0); // 1、降低树高 2、好获取答案
}
功能4:寻找前驱结点
步骤:
- 将需要寻找前驱的val移到根节点
- 判断根结点的val是否小于该val。小于的话,直接作为前驱返回。
- 如果不小于的话, 那么从左子树开始,不断地向右子树递归
- 直到找到最深的那个结点
- 最后spaly 一下该x 即 前驱结点
参考代码
/**
* @param val 待查找前驱的值
* @return 返回前驱的下标
*/
int get_pre(int val){
find(val);
int x = root;
if(tr[x].val < val) return x;
x = tr[x].s[0]; // 从左子树开始
while(tr[x].s[1]) // 一直向右走
x = tr[x].s[1];
splay(x , 0); // 调整一下树高
return x;
}
功能5:寻找后继结点
步骤:
- 需要寻找后继的val移到根结点 , 使用find函数即可
- 判断根结点的val是否大于该val。如果大于的话 , 直接返回该后继结点即可。
- 如果不大于的val,从右子树开始,不断地向左边搜索
- 直到最深的结点
- 最后spaly 一下该x 即 后继结点
参考代码:
/**
* @param val 待查找后继的点
* @return 返回后继的下标
*/
int get_suc(int val){
find(val);
int x = root;
if(tr[x].val > val) return x;
x = tr[x].s[1]; // 从右子树开始
while(tr[x].s[0]) // 一直向左边走
x = tr[x].s[0];
splay(x , 0);
return x;
}
功能6: 删除结点
步骤:
- 将该val的结点的前驱结点旋转到根结点
- 将该val的后继结点后继结点旋转到根节点的成为右儿子
- 那么我们可以得到val一定是后继结点的左儿子而且这个左儿子一定是叶子结点。
- 拿到val这个结点判断cnt
- 如果大于1的话 , 直接减1即可 , 再splay一下这个需删除的结点
- 如果等于1的话 ,后继结点直接丢掉左儿子即可 ,再splay一下后继结点
-
为什么需要splay 一下呢 ?
- 因为删除这个结点,需要重新统计一下受影响的树的size , 其次就是调整树的深度
-
为什么我们不直接删除这个结点,而是通过将val这个结点旋转到叶子结点呢?
- 答案是:如果直接删除的话,至少需要操作6根指针。更新size更是非常的麻烦!所以将val这个结点旋转到叶子结点是为了更好处理。
-
为什么不特判结点不存在?
- 因为不存在的话 , 无非对空结点进行操作,将后继结点splay ,可以直接cnt = 1的处理方法,所以不需要特判。
参考代码
/**
* @param val 待删除的结点
*/
void del(int val){
int pre = get_pre(val);
int suc = get_suc(val);
splay(pre , 0 ) , splay(suc , pre);
int del = tr[suc].s[0]; // val 结点
// splay 因为删除一个结点 , 所以需要splay 重新统计size 信息
if(tr[del].cnt > 1)
tr[del].cnt -- , splay(del , 0 );
else
tr[suc].s[0] = 0 , splay(suc , 0);
}
功能7: 通过排名查val
步骤:
- 插入结点val(插入结点后面会详细讲 , 会将这个val移到根)
- 记录根结点的左子树的size
- 删除val这个结点
- 返回答案即可
这里作者发现了一个错误的思路,导致被困4小时,疯狂被hack ,题目点此
这组数据
5
1 20
1 24
1 25
3 23
3 23
如果我们使用
find(val);
return tr[tr[root].s[0]].size;
得到的答案是:2 1
实际答案是 : 2 2
这是我找到的被hack的原因,都在图上。
功能8: 通过val查排名
步骤:
- 通过性质判断
- 如果小于 , 则向左搜索
- 如果大于 , 则向右搜索 ,因为统计的size是子树的所以 k -=tr[tr[x].s[0]].size +tr[x].cnt
- 如果等于或者找到空结点,则停止搜索
- splay 一下答案结点x,返回答案x
参考代码:
/**
* @param k 查找排名为k的下标
* @return 返回查找到的值
*/
int get_val(int k){ // 根据排名找值
int x = root;
while(1){
int y = tr[x].s[0];
if(tr[y].size + tr[x].cnt < k){ // 说明在右子树
k -= tr[y].size + tr[x].cnt;
x = tr[x].s[1];
}else
if(tr[y].size >= k) x = tr[x].s[0];
else
break;
}
splay(x , 0);
return tr[x].val;
}
功能9:插入结点
步骤:
- 从根节点出发
- 判断与根结点val的关系,向左子树或者右子树搜索,同时记录该结点的父节点。
- 如果搜索到该结点, 就停下来。
- 判断一下该结点是不是空结点。如果不是空,那么找到树中存在的结点。那么直接tr[x].cnt 加一即可 如果是空结点的话,那么创建该结点。
- 最后splay一下,统计一下树的size
参考代码:
/**
* @param val 需要插入的值
*/
void insert(int val){
int x = root , p = 0;
while(x && tr[x].val != val)
p = x , x = tr[x].s[val > tr[x].val];
if(x)
tr[x].cnt ++;
else{
x = ++ idx; // 建结点
tr[p].s[val > tr[p].val] = x; // 父节点指向子节点
tr[x].init(val , p);
}
splay(x , 0); // 重新统计树的 size
}
可以发现,这棵树代码非常的长!但是功能超全! 这种数据结构成功的将BST 取代了, 如此优雅的一棵树啊!
总的参考代码:
#include <iostream>
using namespace std;
const int N = 100010 , INF = 0x3f3f3f3f;
struct node {
int p , s[2]; // 父节点 , 左右孩子
int val; // 值
int size , cnt;
void init(int v , int fa){
size = cnt = 1;
p = fa , val = v;
}
}tr[N];
int root , idx;
int n;
/**
* @param x 代表当前子树的根节点
* 统计这颗子树的大小 size
*/
void pushup(int x){
tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + tr[x].cnt;
}
/**
* @param x 传入需要旋转的结点
* 假设 x 是左儿子 , 那么右旋
* 假设 x 是右儿子 , 那么左旋
*
* @variety
* \t y 为初始旋转的根\n
* \t z 为 y为根的这颗子树的根节点\n
* \t k 代表子树的方向 0 代表左子树 , 1 代表右子树\n
* \t k ^ 1 代表与 x 子树方向 相反方向的子树 \n 例如 k == 1 代表右子树 , 那么 k ^ 1 指代的左子树
*/
void rotate(int x) {
int y = tr[x].p , z = tr[y].p , k = tr[y].s[1] == x;
// x相反方向的子树b , 成为y的与x方向相同方向子树. 即 b 和 y 建立连接
tr[tr[x].s[k ^ 1]].p = y;
tr[y].s[k] = tr[x].s[k ^ 1];
// x 与 y 建立连接
tr[y].p = x;
tr[x].s[k ^ 1] = y;
// x 与 z 建立连接
tr[z].s[tr[z].s[1] == y] = x;
tr[x].p = z;
// 旋转之后需要 pushup , 因为子树发生改变 , 需要重新统计size
pushup(y) , pushup(x); // 注意受影响的子树 , 自底向上统计 y ---> x
}
/**
*
* @param x 代表 x 需要旋转的点
* @param k 代表需要旋转到下面的点
*/
void splay(int x , int k){
while(tr[x].p != k){ // 旋转到 k 下面就停止
int y = tr[x].p , z = tr[y].p;
if(z != k) // 说明需要双旋 折转底 , 直转中
(tr[y].s[0] == x) ^ (tr[z].s[0] == y) ? rotate(x) : rotate(y);
rotate(x);
}
if(k == 0) root = x; // 把x旋转到根 , 那么更新根
}
/**
* @param val 需要插入的值
*/
void insert(int val){
int x = root , p = 0;
while(x && tr[x].val != val)
p = x , x = tr[x].s[val > tr[x].val];
if(x)
tr[x].cnt ++;
else{
x = ++ idx; // 建结点
tr[p].s[val > tr[p].val] = x; // 父节点指向子节点
tr[x].init(val , p);
}
splay(x , 0); // 重新统计树的 size
}
/**
* @param val 待查找的值 , 将该旋转到根结点
*/
void find(int val){
int x = root;
while(tr[x].s[val > tr[x].val] && val != tr[x].val)
x = tr[x].s[val > tr[x].val]; // 向下找
splay(x , 0); // 1、降低树高 2、好获取答案
}
/**
* @param val 待查找前驱的值
* @return 返回前驱的下标
*/
int get_pre(int val){
find(val);
int x = root;
if(tr[x].val < val) return x;
x = tr[x].s[0]; // 从左子树开始
while(tr[x].s[1]) // 一直向右走
x = tr[x].s[1];
splay(x , 0); // 调整一下树高
return x;
}
/**
* @param val 待查找后继的点
* @return 返回后继的下标
*/
int get_suc(int val){
find(val);
int x = root;
if(tr[x].val > val) return x;
x = tr[x].s[1]; // 从右子树开始
while(tr[x].s[0]) // 一直向左边走
x = tr[x].s[0];
splay(x , 0);
return x;
}
/**
* @param val 待删除的结点
*/
void del(int val){
int pre = get_pre(val);
int suc = get_suc(val);
splay(pre , 0 ) , splay(suc , pre);
int del = tr[suc].s[0]; // val 结点
// splay 因为删除一个结点 , 所以需要splay 重新统计size 信息
if(tr[del].cnt > 1)
tr[del].cnt -- , splay(del , 0 );
else
tr[suc].s[0] = 0 , splay(suc , 0);
}
/**
* @param val 待查找排名的值
* @return 返回该值的排名
*/
int get_rank(int val){ // 获取排名
insert(val);
int res = tr[tr[root].s[0]].size;
del(val);
return res; // 因为还有一个哨兵 , 所以不需要加一
}
/**
* @param k 查找排名为k的下标
* @return 返回查找到的值
*/
int get_val(int k){ // 根据排名找值
int x = root;
while(1){
int y = tr[x].s[0];
if(tr[y].size + tr[x].cnt < k){ // 说明在右子树
k -= tr[y].size + tr[x].cnt;
x = tr[x].s[1];
}else
if(tr[y].size >= k) x = tr[x].s[0];
else
break;
}
splay(x , 0);
return tr[x].val;
}
int main(){
insert(-INF) , insert(INF);
scanf("%d" , &n);
int op , x;
while(n -- ){
scanf("%d%d" , &op , &x);
if(op == 1) insert(x);
else if(op == 2) del(x);
else if(op == 3) printf("%d\n" , get_rank(x));
else if(op == 4) printf("%d\n" , get_val(x + 1));
else if(op == 5) printf("%d\n" , tr[get_pre(x)].val);
else printf("%d\n" , tr[get_suc(x)].val);
}
return 0;
}