返回顶部

【模板】 Splay

splay

#include <bits/stdc++.h>

using namespace std;

// Capacity of the node pool. Node id 0 is used as the "no node" sentinel.
const int N = 5000010;

// Per-node fields, indexed by node id (1..tot).
int val[N];    // key stored in the node
int cnt[N];    // multiplicity of that key
int fa[N];     // parent id (0 for the root)
int ch[N][2];  // child ids: [0] = left, [1] = right (0 = none)
int siz[N];    // total multiplicity of the subtree rooted here
int tot, root; // last allocated node id; current root id (0 = empty tree)

// Recompute siz[x] from its two children's subtree sizes plus x's own count.
void maintain(int x) {
	const int left = ch[x][0];
	const int right = ch[x][1];
	siz[x] = cnt[x] + siz[left] + siz[right];
}

// Side of its parent that x hangs on: true for the right child, false otherwise.
bool get(int x) {
	const int parent = fa[x];
	return ch[parent][1] == x;
}

// Zero every field of node x (links, key, count and subtree size).
void clear(int x) {
	val[x] = cnt[x] = siz[x] = 0;
	fa[x] = 0;
	ch[x][0] = 0;
	ch[x][1] = 0;
}

void rotate(int x) {
	int y = fa[x], z = fa[y], chk = get(x);
	ch[y][chk] = ch[x][chk ^ 1];
	if (ch[x][chk ^ 1]) fa[ch[x][chk ^ 1]] = y;
	ch[x][chk ^ 1] = y;
	fa[y] = x;
	fa[x] = z;
	if (z) ch[z][y == ch[z][1]] = x;
	maintain(y);
	maintain(x);
}

void splay(int x) {
	for (int f = fa[x]; f = fa[x], f; rotate(x)) {
		if (fa[f]) rotate(get(x) == get(f) ? f : x);
	}
	root = x;
}

// Add one occurrence of k. An existing node for k just has its count bumped;
// otherwise a new leaf is allocated. The touched node is splayed to the root.
void insert(int k) {
	// Empty tree: the new node becomes the root directly.
	if (!root) {
		root = ++tot;
		val[root] = k;
		++cnt[root];
		maintain(root);
		return;
	}
	int parent = 0;
	for (int node = root;;) {
		if (val[node] == k) {
			++cnt[node];
			maintain(node);
			maintain(parent);
			splay(node);
			return;
		}
		parent = node;
		const int dir = val[node] < k;  // larger keys go right
		node = ch[node][dir];
		if (node) continue;

		// Fell off the tree: hang a fresh leaf under `parent`.
		const int leaf = ++tot;
		val[leaf] = k;
		++cnt[leaf];
		fa[leaf] = parent;
		ch[parent][dir] = leaf;
		maintain(leaf);
		maintain(parent);
		splay(leaf);
		return;
	}
}

int rk(int k) {
	int res = 0, cur = root;
	// ※※※ 
	while (cur) {
		if (val[cur] > k) {
			cur = ch[cur][0];
		}
		else {
			res += siz[ch[cur][0]];
			if (val[cur] == k) {
				splay(cur);
				return res + 1;
			}
			res += cnt[cur];
			cur = ch[cur][1];
		}
	}
	// ※※※
	return res + 1;
}

int kth(int k) {
	int cur = root;
	if (!cur) return 0;
	while (true) {
		if (ch[cur][0] && siz[ch[cur][0]] >= k) {
			cur = ch[cur][0];
		}
		else {
			k -= cnt[cur] + siz[ch[cur][0]];
			if (k <= 0) {
				splay(cur);
				return val[cur];
			}
			cur = ch[cur][1];
		}
	}
}

//insert x first, x is root now, just go find it and delete x at last
int pre() { int cur = ch[root][0]; if (!cur) return 0; while (ch[cur][1]) cur
= ch[cur][1]; splay(cur); return cur; }

int suf() {
	int cur = ch[root][1];
	if (!cur) return 0;
	while (ch[cur][0]) cur = ch[cur][0];
	splay(cur);
	return cur;
}

void del(int k) {
	rk(k);		//find the pos of k and splay it to the root
	if (cnt[root] > 1) {
		cnt[root]--;
		maintain(root);
		return;
	}
	if (!ch[root][0] && !ch[root][1]) {
		clear(root);
		root = 0;
		return;
	}
	if (!ch[root][0]) {
		int cur = root;
		root = ch[root][1];
		fa[root] = 0;
		clear(cur);
		return;
	}
	if (!ch[root][1]) {
		int cur = root;
		root = ch[root][0];
		fa[root] = 0;
		clear(cur);
		return;
	}
	int cur = root, x = pre();
	fa[ch[cur][1]] = x;
	ch[x][1] = ch[cur][1];
	clear(cur);
	maintain(root);
}

// Reads n initial values and m online queries, then prints the XOR of all
// answers produced by query types 3-6.
int main() {
	int n, m;
	// Stop on malformed/truncated input instead of using uninitialized values.
	if (scanf("%d %d", &n, &m) != 2) return 1;
	for (int i = 1; i <= n; ++i) {
		int x;
		if (scanf("%d", &x) != 1) return 1;
		insert(x);
	}
	// Each query operand is XOR-ed with `last`, the most recent answer of a
	// type 3-6 query; `res` accumulates the XOR of all such answers.
	int last = 0, res = 0;
	while (m--) {
		int op, x;
		if (scanf("%d %d", &op, &x) != 2) return 1;
		x ^= last;
		if (op == 1) {
			insert(x);
		}
		else if (op == 2) {
			del(x);
		}
		else if (op == 3) {
			last = rk(x);
			res ^= last;
		}
		else if (op == 4) {
			last = kth(x);
			res ^= last;
		}
		else if (op == 5) {
			// Predecessor: temporarily insert x so it is the root.
			insert(x);
			last = val[pre()];
			res ^= last;
			del(x);
		}
		else if (op == 6) {
			// Successor: same trick as op 5, mirrored.
			insert(x);
			last = val[suf()];
			res ^= last;
			del(x);
		}
	}
	printf("%d\n", res);
	/*
	Alternative driver for the unencrypted variant: reads n operations and
	prints each answer on its own line (no XOR decoding).
	int n;
	cin >> n;
	//Node *root = nullptr;
	while (n--) {
		int op, x;
		cin >> op >> x;
		if (op == 1) insert(x);
		if (op == 2) del(x);
		if (op == 3) cout << rk(x) << endl;
		if (op == 4) cout << kth(x) << endl;
		if (op == 5) {
			insert(x);
			cout << val[pre()] << endl;
			del(x);
		}
		if (op == 6) {
			insert(x);
			cout << val[suf()] << endl;
			del(x);
		}
	}
	*/
	return 0;
}
posted @ 2023-05-06 17:26  Rayotaku  阅读(16)  评论(0)  编辑  收藏  举报