P3369 【模板】普通平衡树【Splay】

149 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

传送门

分析

您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:

  • 插入 xx
  • 删除 xx 数(若有多个相同的数,因只删除一个)
  • 查询 xx 数的排名(排名定义为比当前数小的数的个数 +1+1 )
  • 查询排名为 xx 的数
  • xx 的前驱(前驱定义为小于 xx,且最大的数)
  • xx 的后继(后继定义为大于 xx,且最小的数)

行!就用Splay了,只要理解,rotate和splay,直接往上套,别管什么常数了

  • 插入,找到xx的前驱和后继,先旋前驱到根,再旋后继到根右子树,在后继的左子树上直接加,判断是否存在,如果存在的话,直接加个数即可
  • 删除,同样操作,找到 xx 的前驱和后继,同样的旋转,在后继的左子树上删除值,判断个数来删除节点
  • 查询排名,二分值找到节点, 旋转到根,输出左子树的 szsz 即可(因为中序遍历有序)
  • 查第几大,在 SplaySplay 上通过 szsz 二分查找即可
  • 找前驱,二分找到节点 xx,往左子树的右子树找值
  • 找后继,二分找到节点xx,往右子树的左子树找值

此时注意) 鉴于找值可能会找不到值,所以只找到自己,或者前驱后继为止(不会到空) 前驱后继那里特判是否直接就找到了前驱或者后继,之际返回即可

代码

//P3369
/*
  @Author: YooQ
*/
#include <bits/stdc++.h>
using namespace std;
#define sc scanf
#define pr printf
#define ll long long
#define FILE_OUT freopen("out", "w", stdout);
#define FILE_IN freopen("in", "r", stdin);
#define debug(x) cout << #x << ": " << x << "\n";
#define AC 0
#define WA 1
#define INF 0x3f3f3f3f
const ll MAX_N = 1e6+5;
const ll MOD = 1e9+7;
int N, M, K;

struct Tr {
	int k, fa, sz, cnt;
	int son[2];
	int& l = son[0];
	int& r = son[1];
}tr[MAX_N];
int indx = 0;
int root = 0;

void push_up(int rt) {
	tr[rt].sz = tr[tr[rt].l].sz + tr[tr[rt].r].sz + tr[rt].cnt;
}

int which(int x) {
	return tr[tr[x].fa].son[1] == x;
}

void rotate(int x) {
	int p = tr[x].fa;
	int q = tr[p].fa;
	int side = which(x);
	
	tr[tr[p].son[side] = tr[x].son[side^1]].fa = p;
	tr[tr[x].son[side^1] = p].fa = x;
	tr[x].fa = q;
	
	if (q) {	
		tr[q].son[tr[q].son[1] == p] = x;
	}
	
	push_up(p);
	push_up(x);
}

void splay(int x, int tar) {
	for (int p; (p = tr[x].fa) != tar; rotate(x)) {
		if (tr[p].fa != tar) rotate(which(x) == which(p) ? p : x);
	}
	if (!tar) root = x;
} 

void insert(int x) {
	int rt = root, p = 0;
	while (rt && tr[rt].k != x) {
		p = rt;
		rt = tr[rt].son[x > tr[rt].k];
	}
	if (rt) {
		tr[rt].cnt++;
		splay(rt, 0);
		return;
	}
	rt = ++indx;
	if (p) {
		tr[p].son[x > tr[p].k] = rt;
	}
	tr[rt].fa = p;
	tr[rt].k = x;
	tr[rt].sz = tr[rt].cnt = 1;
	splay(rt, 0);
}

int find(int x) {
	int rt = root;
	while (tr[rt].son[x > tr[rt].k] && tr[rt].k != x) {
		rt = tr[rt].son[x > tr[rt].k];
	}
	splay(rt, 0);
	return rt;
}

int find_pre(int x) {
	int rt = find(x);
	if (tr[rt].k < x) return rt;
	rt = tr[rt].l;
	while (tr[rt].r) {
		rt = tr[rt].r;
	}
	return rt;
}

int find_nxt(int x) {
	int rt = find(x);
	if (tr[rt].k > x) return rt;
	rt = tr[rt].r;
	while (tr[rt].l) {
		rt = tr[rt].l;
	}
	return rt;
}

void del(int x) {
	int pre = find_pre(x);
	int nxt = find_nxt(x);
	splay(pre, 0);
	splay(nxt, pre);
	int rt = tr[nxt].l;
	if (tr[rt].cnt > 1) {
		--tr[rt].cnt;
		splay(rt, 0);
	} else {
		tr[nxt].l = 0;
	}
}

int query_rk(int x) {
	int rt = find(x);
	return tr[tr[rt].l].sz;
}

int fetch(int x) {
	int rt = root;
	while (x) {
		if (tr[tr[rt].l].sz >= x) {
			rt = tr[rt].l;
			continue;
		}
		x -= tr[tr[rt].l].sz;
		if (x <= tr[rt].cnt) return rt;
		x -= tr[rt].cnt;
		rt = tr[rt].r;
	}
}

void solve(){
	sc("%d", &N);
	insert(-1e9);
	insert(1e9);
	
	int opt, x;
	for (int i = 1; i <= N; ++i) {
		sc("%d%d", &opt, &x);
		if (opt == 1) {
			insert(x);
		} else if (opt == 2) {
			del(x);
		} else if (opt == 3) {
			pr("%d\n", query_rk(x), i);
		} else if (opt == 4) {
			pr("%d\n", tr[fetch(x+1)].k, i);
		} else if (opt == 5) {
			pr("%d\n", tr[find_pre(x)].k, i);
		} else {
			pr("%d\n", tr[find_nxt(x)].k, i);
		}
	}
}

signed main()
{
	#ifndef ONLINE_JUDGE
	//FILE_IN
	FILE_OUT
	#endif
	int T = 1;//cin >> T;
	while (T--) solve();

	return AC;
}