平衡树

简单写一下 splay。

刚刚写了一下没写出除了少 splay 以外的任何锅来。令人感慨。

所有操作都是基于朴素二叉搜索树,只是额外加入了一个 splay 操作来保证均摊复杂度。

\(splay(u,x)\) 表示把 \(u\) 旋转到 \(x\) 的儿子位置(如果 \(x=0\) 则转到根)。

旋转采用双旋保证复杂度。对于一个点 \(x\)、它的父亲 \(y\) 以及 \(y\) 的父亲 \(z\),如果 \(x\) 是 \(y\) 的左/右儿子与 \(y\) 是 \(z\) 的左/右儿子方向相同(三点共线),那么先转 \(y\) 再转 \(x\),否则转两次 \(x\)。

旋转的方式手玩一下就能玩出来,保证二叉搜索树的性质即可。

然后注意不管干什么操作,如果找到了某个点,都要在最后 \(splay\) 这个点到根以保证复杂度。

code:

#include<bits/stdc++.h>
//#include <ext/pb_ds/hash_policy.hpp>
//#include <ext/pb_ds/assoc_container.hpp>
// Splay-tree ordered multiset plus the driver that answers the online queries.
//
// Node 0 is the null node: its links and size stay 0 at all times. The query
// helpers (rk / kth / next / del) assume that the two sentinels -INF and INF
// have been inserted first (main() does this), so every real value always has
// both a predecessor and a successor in the tree.
namespace infinities{
	// Legacy shorthand macros, kept so any other code relying on them still builds.
	// `fint` no longer expands to `register int`: the `register` storage class
	// was removed in C++17 and is rejected by some compilers.
	#define fint int
	#define ls(i) (i << 1)
	#define rs(i) ((i << 1) | 1)
	#define pii pair<int, int>
	#define im int mid = (l + r) >> 1
	#define INT __int128
	#define ll long long
	#define ui unsigned int
	#define ull unsigned long long
	#define lc ch[now][0]
	#define rc ch[now][1]
	// mod: modulus used by the mystd helpers; INF: sentinel key (fits in int);
	// maxn: capacity of the node arrays.
	const int mod = 998244353, INF = 1e9 + 6e8 + 7, maxn = 2e6 + 2e5 + 7;
	namespace FastIO{//10M
		const int SS = 1e7; char num_[50]; int cnt_;
		// Buffered reader over stdin using fread; terminates the process with
		// exit(0) once the input is exhausted. Not used by read() below.
		inline int xchar(){
			static char buf[SS];
			static int len = 0, pos = 0;
			if(pos == len)pos = 0, len = fread(buf, 1, SS, stdin);
			if(pos == len)exit(0);
			return buf[pos++];
		}
		// Reads one (optionally negative) decimal integer from stdin via getchar.
		// Returns 0 if the input ends before any digit is seen; the EOF check
		// prevents the skip loop from spinning forever on truncated input.
		inline int read(){
			int x = 0, f = 1, c = getchar();
			while(c != EOF && (c < '0' || c > '9')){
				if(c == '-')f = -1;
				c = getchar();
			}
			while(c >= '0' && c <= '9'){
				x = (x << 1) + (x << 3) + (c ^ 48);
				c = getchar();
			}
			return x * f;
		}
		// Writes x in decimal to stdout. The magnitude is held in a long long so
		// that negating INT_MIN does not overflow.
		inline void write(int x){
			if(x == 0){putchar('0'); return;}
			long long v = x;
			if(v < 0)putchar('-'), v = -v;
			while(v)num_[++cnt_] = static_cast<char>('0' + v % 10), v /= 10;
			while(cnt_)putchar(num_[cnt_--]);
		}
		// Writes x followed by the single character c.
		inline void write(int x, char c){write(x), putchar(c);}
	}
	//using namespace __gnu_pbds;
	//gp_hash_table<int, int>hsh;
	// Debug helper: redirects stdin to a local test file. Not called by main().
	void file(){freopen("P6136_5.in", "r", stdin);}
	using namespace std;
	using namespace FastIO;
	namespace mystd{
		// a = min(a, b) / a = max(a, b).
		inline void chkmin(int &a, int b){a = std::min(a, b);}
		inline void chkmax(int &a, int b){a = std::max(a, b);}
		// a = (a + b) with one conditional subtraction of mod.
		inline void qk(int &a, int b){
			a = a + b;
			if(a >= mod)a -= mod;
		}
		// a = (a - b + mod) with one conditional subtraction of mod.
		inline void sub(int &a, int b){
			a = a - b + mod;
			if(a >= mod)a -= mod;
		}
		// Binary exponentiation: x^k modulo mod.
		inline int qpow(int x, int k){
			int res = 1;
			for(; k; k >>= 1, x = 1ll * x * x % mod)
				if(k & 1)res = 1ll * res * x % mod;
			return res;
		}
	}
	using namespace mystd;
	// Tree storage, indexed by node id (0 = null node):
	//   ch[i][0/1] left/right child, fa[i] parent, val[i] key,
	//   cnt[i] multiplicity of the key, size[i] sum of cnt over i's subtree.
	// rt is the root id, tot the number of allocated nodes.
	int n, ch[maxn][2], size[maxn], cnt[maxn], fa[maxn], rt, val[maxn], tot;
	// Allocates the next node id with all fields cleared.
	int newnode(){
		++tot;
		size[tot] = ch[tot][0] = ch[tot][1] = fa[tot] = cnt[tot] = val[tot] = 0;
		return tot;
	}
	// Recomputes size[now] from its children and its own multiplicity.
	void pu(int now){size[now] = size[ch[now][0]] + size[ch[now][1]] + cnt[now];}
	// Rotates x above its parent y, preserving the search-tree order.
	// x must be non-null and must have a parent.
	void rotate(int x){
		const int y = fa[x], z = fa[y];
		const bool fl = (ch[y][1] == x);
		const int inner = ch[x][fl ^ 1];	// subtree that moves from x to y
		ch[y][fl] = inner;
		if(inner)fa[inner] = y;			// never write into the null node
		ch[x][fl ^ 1] = y, fa[y] = x;
		fa[x] = z;
		if(z)ch[z][ch[z][1] == y] = x;		// y was the root when z == 0
		pu(y), pu(x);
	}
	// Moves x upward with double rotations until its parent is v; with v == 0
	// x becomes the root. v must be 0 or an ancestor of x. A null x is ignored
	// so that the root can never be replaced by the null node.
	void splay(int x, int v){
		if(!x)return;
		while(fa[x] != v){
			const int y = fa[x], z = fa[y];
			if(z == v){rotate(x); break;}
			// Same direction twice (zig-zig): rotate the parent first.
			if((ch[z][1] == y) == (ch[y][1] == x))rotate(y);
			else rotate(x);
			rotate(x);
		}
		if(v == 0)rt = x;
	}
	// Searches for key x and splays the last node on the search path to the
	// root. Returns true iff that node holds x (false on an empty tree).
	bool find(int x){
		if(!rt)return 0;
		int now = rt;
		while(val[now] != x && ch[now][x > val[now]])now = ch[now][x > val[now]];
		splay(now, 0);
		return val[now] == x;
	}
	// Inserts one occurrence of x; the touched node ends up at the root.
	void insert(int x){
		if(!rt){rt = newnode(); cnt[rt] = size[rt] = 1, val[rt] = x; return;}
		if(find(x)){++cnt[rt]; pu(rt); return;}
		int now = rt;
		while(ch[now][x > val[now]])now = ch[now][x > val[now]];
		ch[now][x > val[now]] = newnode();
		fa[tot] = now, val[tot] = x, cnt[tot] = 1, size[tot] = 1;
		splay(tot, 0);
	}
	// Returns the node id of the strict predecessor (op == 0) or strict
	// successor (op == 1) of x, or 0 when no such node exists.
	int next(int x, int op){
		find(x);
		if(val[rt] < x && op == 0)return rt;
		if(val[rt] > x && op == 1)return rt;
		int now = ch[rt][op];
		if(!now)return 0;	// nothing on that side: do not splay the null node
		while(ch[now][op ^ 1])now = ch[now][op ^ 1];
		splay(now, 0);
		return now;
	}
	// Removes one occurrence of x if present. Works by isolating x as the left
	// child of its successor, directly below its predecessor; when either
	// neighbour is missing (x is a sentinel) nothing is removed.
	void del(int x){
		if(!find(x))return;
		const int pre = next(x, 0), nxt = next(x, 1);
		if(!pre || !nxt)return;
		splay(pre, 0), splay(nxt, pre);
		const int now = ch[nxt][0];
		if(now == 0)return;
		if(cnt[now] > 1){--cnt[now]; splay(now, 0); return;}	// splay refreshes sizes
		fa[now] = ch[nxt][0] = 0;
		splay(nxt, 0);
	}
	// Returns the k-th smallest stored value (1-based, not counting the -INF
	// sentinel). Ranks beyond the stored elements yield the predecessor of INF;
	// non-positive ranks yield the smallest stored key (the -INF sentinel).
	int kth(int k){
		if(k < 0)k = 0;
		// Compare before incrementing so that k + 1 cannot overflow.
		if(size[rt] <= k)return val[next(INF, 0)];
		++k;	// skip the -INF sentinel
		int now = rt;
		while(1){
			if(size[ch[now][0]] >= k){
				now = ch[now][0];
			}else
			if(size[ch[now][0]] + cnt[now] < k){
				k -= size[ch[now][0]] + cnt[now];
				now = ch[now][1];
			}else{splay(now, 0); return val[now];}
		}
	}
	// Returns the rank of x: the number of stored keys smaller than x, where the
	// -INF sentinel contributes the usual "+1".
	int rk(int x){
		find(x);
		return size[ch[rt][0]] + ((val[rt] < x) ? cnt[rt] : 0);
	}
	// m: number of queries, las: previous answer (queries are xor-encoded with
	// it), all: xor of all answers, a: initial elements.
	int m, las = 0, all, a[maxn];
	// Reads n initial values and m queries from stdin; each query argument is
	// xor-ed with the previous answer. Prints the xor of all answers of the
	// query types 3..6 (rank, k-th, predecessor, successor).
	signed main(){
		n = read(), m = read();
		insert(-INF), insert(INF);
		for(int i = 1; i <= n; i++)a[i] = read();
		// std::random_shuffle was removed in C++17; std::shuffle with a fixed
		// seed keeps the run deterministic. Insertion order does not affect
		// the answers.
		std::mt19937 rng(20230329u);
		std::shuffle(a + 1, a + 1 + n, rng);
		for(int i = 1; i <= n; i++)insert(a[i]);
		while(m--){
			const int opt = read(), x = read() ^ las;
			switch(opt){
				case 1: insert(x); break;
				case 2: del(x); break;
				case 3: las = rk(x); break;
				case 4: las = kth(x); break;
				case 5: las = val[next(x, 0)]; break;
				case 6: las = val[next(x, 1)]; break;
				default: break;
			}
			if(opt > 2){all ^= las;}
		}
		cout << all;
		return 0;
	}
}
// Program entry point: hands control to the solver in namespace `infinities`
// and forwards its return value as the process exit status.
int main(){
    const int status = infinities::main();
    return status;
}
posted @ 2023-03-29 20:00  infinities  阅读(13)  评论(0)  编辑  收藏  举报