Splay详解

因为博主太懒,这篇详解暂时咕了(尚未完成),下面只给出两份模板代码。

【模板】普通平衡树(Splay)

#include <bits/stdc++.h>
using namespace std;

// Capacity of the node pool: 1e5 entries plus a small margin.
const int _ = 100000 + 10;
// Large sentinel (0x3f3f3f3f) used as "infinity" by the predecessor/successor queries.
const int INF = 0x3f3f3f3f;

// One splay-tree node. Nodes are addressed by their index in the pool `tr`;
// index 0 is never handed out by alloc and serves as the "no node" marker.
struct node {
	int fa;     // index of the parent (0 for the root)
	int ch[2];  // ch[0] = left child, ch[1] = right child (0 when absent)
	int val;    // key stored in the node
	int cnt;    // multiplicity of the key
	int siz;    // sum of cnt over the subtree rooted at this node
};

node tr[_];   // node pool (zero-initialised global storage)
int root;     // index of the current root, 0 while the tree is empty
int tot = 0;  // number of pool slots handed out so far
int N;        // number of operations, read by main

// Takes the next unused slot of the pool and initialises it as a single-copy
// node holding `val` whose parent is `fa`. Returns the index of the new node.
// The child links are left as they are in the pool (zero for a fresh slot).
int alloc(int val, int fa) {
	const int id = ++tot;
	tr[id].val = val;
	tr[id].fa = fa;
	tr[id].cnt = 1;
	tr[id].siz = 1;
	return id;
}

// Recomputes the subtree size of x from the sizes of its two children and
// its own multiplicity.
inline void update(int x) {
	const int leftSize = tr[tr[x].ch[0]].siz;
	const int rightSize = tr[tr[x].ch[1]].siz;
	tr[x].siz = leftSize + rightSize + tr[x].cnt;
}

// Tells which child slot of its parent x occupies: 1 if x is recorded as the
// parent's right child, 0 otherwise.
inline int ident(int x) {
	const int parent = tr[x].fa;
	return tr[parent].ch[1] == x ? 1 : 0;
}

// Links x below fa as child number `how` (0 = left, 1 = right), writing both
// directions of the edge. No null check is made, so calling it with x == 0 or
// fa == 0 writes into the fields of pool entry 0.
inline void connect(int x, int fa, int how) {
	tr[fa].ch[how] = x;
	tr[x].fa = fa;
}

// Rotates x one level up, above its parent, and refreshes the subtree sizes
// of the two nodes whose children changed. If the parent was the root, x
// becomes the new root.
void rotate(int x) {
	const int parent = tr[x].fa;
	const int grand = tr[parent].fa;
	const int side = ident(x);             // slot x occupies under parent
	const int parentSide = ident(parent);  // slot parent occupies under grand
	const int inner = tr[x].ch[side ^ 1];  // subtree that switches owner

	if (root == parent) root = x;

	connect(inner, parent, side);     // parent adopts x's inner subtree
	connect(parent, x, side ^ 1);     // parent becomes a child of x
	connect(x, grand, parentSide);    // x takes parent's old place
	// parent is now below x, so its size has to be recomputed first.
	update(parent);
	update(x);
}

// Rotates x upwards until its parent is `to`. With to == 0 the node becomes
// the root of the whole tree.
void splay(int x, int to) {
	for (int y = tr[x].fa; y != to; y = tr[x].fa) {
		const int z = tr[y].fa;
		if (z != to) {
			const bool yIsLeft = (tr[z].ch[0] == y);
			const bool xIsLeft = (tr[y].ch[0] == x);
			if (yIsLeft != xIsLeft) {
				rotate(x);  // x and y are on different sides: rotate x twice
			} else {
				rotate(y);  // same side: rotate the parent first
			}
		}
		rotate(x);
	}
	if (to == 0) root = x;
}

// Adds one copy of x to the multiset. Every node passed on the way down has
// its subtree size incremented; the node that finally holds x (existing or
// freshly allocated) is splayed to the root.
void insert(int x) {
	if (root == 0) {
		// Empty tree: the new node is the root, nothing to splay.
		root = alloc(x, 0);
		return;
	}
	for (int u = root;;) {
		tr[u].siz += 1;
		if (x == tr[u].val) {
			// Key already stored: only its multiplicity grows.
			tr[u].cnt += 1;
			splay(u, 0);
			return;
		}
		const int dir = (tr[u].val < x) ? 1 : 0;
		const int child = tr[u].ch[dir];
		if (child == 0) {
			const int fresh = alloc(x, u);
			tr[u].ch[dir] = fresh;
			splay(fresh, 0);
			return;
		}
		u = child;
	}
}

// Looks up the node holding `val`.
// Returns its index after splaying it to the root, or 0 when the key is not
// stored (including when the tree is empty).
//
// The descent stops as soon as it reaches index 0 and never inspects pool
// entry 0 itself: rotate() and remove() call connect() with fa == 0, which
// overwrites tr[0].ch[] / tr[0].fa, so treating entry 0 as a real node could
// walk into nodes that were already deleted from the tree.
int find(int val) {
	int u = root;
	while (u != 0) {
		if (tr[u].val == val) {
			splay(u, 0);
			return u;
		}
		u = tr[u].ch[val > tr[u].val ? 1 : 0];
	}
	return 0;
}

// Removes one copy of x from the multiset; does nothing if x is not stored.
// find() brings the node to the root first, so every case below operates on
// the root.
void remove(int x) {
	const int pos = find(x);
	if (pos == 0) return;

	if (tr[pos].cnt > 1) {
		// Several copies: drop one and keep the node.
		tr[pos].cnt -= 1;
		tr[pos].siz -= 1;
		return;
	}

	const int leftChild = tr[pos].ch[0];
	if (leftChild == 0) {
		const int rightChild = tr[pos].ch[1];
		if (rightChild == 0) {
			// pos was the only node.
			root = 0;
			return;
		}
		// No left subtree: the right child takes over as root.
		root = rightChild;
		tr[root].fa = 0;
		return;
	}

	// Otherwise lift the largest node of the left subtree directly below pos,
	// hang pos's right subtree under it and make it the new root.
	int u = leftChild;
	while (tr[u].ch[1] != 0) u = tr[u].ch[1];
	splay(u, root);
	connect(tr[pos].ch[1], u, 1);
	connect(u, 0, 1);
	root = u;
	update(u);
}

int getrank(int val) {
	int pos = find(val);
	return tr[tr[pos].ch[0]].siz + 1;
}

// Returns the x-th smallest stored value (1-based, counted with
// multiplicity) and splays the node holding it to the root.
// The loop has no exit for an x outside the stored range; callers are
// expected to pass a valid rank.
int kth(int x) {
	for (int u = root;;) {
		const int below = tr[tr[u].ch[0]].siz;              // elements left of u
		const int upTo = tr[u].siz - tr[tr[u].ch[1]].siz;   // left subtree plus u's copies
		if (below < x && x <= upTo) {
			splay(u, 0);
			return tr[u].val;
		}
		if (x < upTo) {
			u = tr[u].ch[0];
		} else {
			// Skip the left subtree and u itself, continue on the right.
			x -= upTo;
			u = tr[u].ch[1];
		}
	}
}

int getpre(int val) {
	int u = root;
	int ans = -INF;
	while (u) {
		if (tr[u].val < val && tr[u].val > ans) ans = tr[u].val;
		if (val > tr[u].val) u = tr[u].ch[1];
		else u = tr[u].ch[0];
	}
	return ans;
}


int getnxt(int val) {
	int u = root;
	int ans = INF;
	while (u) {
		if (tr[u].val > val && tr[u].val < ans) ans = tr[u].val;
		if (val >= tr[u].val) u = tr[u].ch[1];
		else u = tr[u].ch[0];
	}
	return ans;
}

// Driver for the ordered-multiset program. Reads the operation count N, then
// N pairs "op x" and dispatches:
//   1 insert x, 2 remove one x, 3 print rank of x,
//   4 print x-th smallest, 5 print predecessor of x, 6 print successor of x.
// Unless ONLINE_JUDGE is defined, stdin/stdout are redirected to
// splay.in / splay.out.
int main() {
#ifndef ONLINE_JUDGE
	freopen("splay.in", "r", stdin);
	freopen("splay.out", "w", stdout);
#endif
	// Stop quietly on missing or malformed input instead of working with
	// uninitialised values.
	if (scanf("%d", &N) != 1) return 0;
	while (N-- > 0) {
		int op, x;
		if (scanf("%d%d", &op, &x) != 2) break;
		switch (op) {
			case 1: insert(x); break;
			case 2: remove(x); break;
			case 3: printf("%d\n", getrank(x)); break;
			case 4: printf("%d\n", kth(x)); break;
			case 5: printf("%d\n", getpre(x)); break;
			case 6: printf("%d\n", getnxt(x)); break;
			default: break;  // unknown operation codes are ignored
		}
	}
	return 0;
}

【模板】文艺平衡树(Splay区间修改)

#include <bits/stdc++.h>
using namespace std;

// Reads the next decimal integer from stdin using getchar().
// Non-digit characters before the number are skipped; if a '-' is among them
// the result is negated. Returns 0 when end of input is reached before any
// digit is seen.
inline int ty() {
	// getchar() returns int so that EOF can be told apart from every
	// character; storing it in a char made the skip loop spin forever at
	// end of input.
	int ch = getchar();
	int sign = 1;
	while (ch != EOF && (ch < '0' || ch > '9')) {
		if (ch == '-') sign = -1;
		ch = getchar();
	}
	int x = 0;
	while (ch >= '0' && ch <= '9') {
		x = x * 10 + (ch - '0');
		ch = getchar();
	}
	return x * sign;
}

// Capacity of the node pool and of the value array: 1e5 entries plus a small margin.
const int _ = 100000 + 10;
// Large sentinel (0x3f3f3f3f); +INF and -INF mark the two boundary elements.
const int INF = 0x3f3f3f3f;

// One splay-tree node of the sequence. Nodes are addressed by their index in
// the pool `tr`; index 0 stands for "no node".
struct node {
	int fa;     // index of the parent (0 for the root)
	int ch[2];  // ch[0] = left child, ch[1] = right child (0 when absent)
	int val;    // value stored at this position of the sequence
	int cnt;    // number of elements represented by the node (set to 1 by build)
	int siz;    // number of elements in the subtree rooted at this node
	int tag;    // pending "reverse this subtree" flag (0 or 1)
};

node tr[_];    // node pool (zero-initialised global storage)
int root = 1;  // index of the root; build gives index 1 to its outermost call
int tot = 0;   // number of pool slots handed out so far
int N;         // length of the sequence
int M;         // number of reversal operations
int num[_];    // values to build the tree from, including both sentinels

// Recomputes the subtree size of x from the sizes of its two children and
// its own count.
inline void update(int x) {
	const int leftSize = tr[tr[x].ch[0]].siz;
	const int rightSize = tr[tr[x].ch[1]].siz;
	tr[x].siz = leftSize + rightSize + tr[x].cnt;
}

// Tells which child slot of its parent x occupies: 1 if x is recorded as the
// parent's right child, 0 otherwise.
inline int ident(int x) {
	const int parent = tr[x].fa;
	return tr[parent].ch[1] == x ? 1 : 0;
}

// Links x below fa as child number `how` (0 = left, 1 = right), writing both
// directions of the edge. No null check is made, so calling it with x == 0 or
// fa == 0 writes into the fields of pool entry 0.
inline void connect(int x, int fa, int how) {
	tr[fa].ch[how] = x;
	tr[x].fa = fa;
}

// Applies a pending reversal at x: the flag is toggled on both children, the
// children of x trade places, and x's own flag is cleared. Does nothing for
// x == 0 or when no reversal is pending.
inline void pushdown(int x) {
	if (x == 0 || tr[x].tag == 0) return;
	const int leftChild = tr[x].ch[0];
	const int rightChild = tr[x].ch[1];
	tr[leftChild].tag ^= 1;
	tr[rightChild].tag ^= 1;
	tr[x].ch[0] = rightChild;
	tr[x].ch[1] = leftChild;
	tr[x].tag = 0;
}

// Builds a subtree from num[l..r] by placing the middle element at the top
// and recursing on both halves. `fa` is the index of the parent of the
// subtree. Returns the index of the subtree's top node, or 0 for an empty range.
int build(int l, int r, int fa) {
	if (r < l) return 0;
	const int mid = (l + r) >> 1;
	const int cur = ++tot;
	tr[cur].fa = fa;
	tr[cur].val = num[mid];
	tr[cur].cnt += 1;
	tr[cur].siz += 1;
	tr[cur].tag = 0;
	const int leftTop = build(l, mid - 1, cur);
	tr[cur].ch[0] = leftTop;
	const int rightTop = build(mid + 1, r, cur);
	tr[cur].ch[1] = rightTop;
	update(cur);
	return cur;
}

// Rotates x one level up, above its parent, and refreshes the subtree sizes
// of the two nodes whose children changed. If the parent was the root, x
// becomes the new root. No reversal flags are pushed down here.
void rotate(int x) {
	const int parent = tr[x].fa;
	const int grand = tr[parent].fa;
	const int side = ident(x);             // slot x occupies under parent
	const int parentSide = ident(parent);  // slot parent occupies under grand
	const int inner = tr[x].ch[side ^ 1];  // subtree that switches owner

	if (root == parent) root = x;

	connect(inner, parent, side);     // parent adopts x's inner subtree
	connect(parent, x, side ^ 1);     // parent becomes a child of x
	connect(x, grand, parentSide);    // x takes parent's old place
	// parent is now below x, so its size has to be recomputed first.
	update(parent);
	update(x);
}

// Rotates x upwards until its parent is `to`. With to == 0 the node becomes
// the root of the whole tree.
void splay(int x, int to) {
	for (int y = tr[x].fa; y != to; y = tr[x].fa) {
		const int z = tr[y].fa;
		if (z != to) {
			const bool yIsLeft = (tr[z].ch[0] == y);
			const bool xIsLeft = (tr[y].ch[0] == x);
			if (yIsLeft != xIsLeft) {
				rotate(x);  // x and y are on different sides: rotate x twice
			} else {
				rotate(y);  // same side: rotate the parent first
			}
		}
		rotate(x);
	}
	if (to == 0) root = x;
}

// Locates the node at position x of the sequence (1-based, sentinels
// included), pushing pending reversals down along the way. The node is
// splayed to the root and its index is returned.
// The loop has no exit for a position outside the sequence.
int find(int x) {
	for (int u = root;;) {
		pushdown(u);
		const int below = tr[tr[u].ch[0]].siz;              // elements left of u
		const int upTo = tr[u].siz - tr[tr[u].ch[1]].siz;   // left subtree plus u itself
		if (below < x && x <= upTo) {
			splay(u, 0);
			return u;
		}
		if (x < upTo) {
			u = tr[u].ch[0];
		} else {
			// Skip the left subtree and u, continue on the right.
			x -= upTo;
			u = tr[u].ch[1];
		}
	}
}

// Marks positions l..r of the sequence as reversed. The node just before the
// range is splayed to the root and the node just after it directly below the
// root; the left subtree of that second node then gets its reversal flag
// toggled.
void rever(int l, int r) {
	const int before = find(l - 1);
	const int after = find(r + 1);
	splay(before, 0);
	splay(after, before);
	const int rightOfRoot = tr[root].ch[1];
	const int range = tr[rightOfRoot].ch[0];
	tr[range].tag ^= 1;
}

// In-order traversal of the subtree rooted at x. Pending reversals are
// resolved before the children are visited. Each value is printed followed
// by a space; the two boundary values INF and -INF are skipped.
void dfs(int x) {
	pushdown(x);
	const int leftChild = tr[x].ch[0];
	if (leftChild != 0) dfs(leftChild);
	const int value = tr[x].val;
	if (!(value == INF || value == -INF)) printf("%d ", value);
	const int rightChild = tr[x].ch[1];
	if (rightChild != 0) dfs(rightChild);
}

// Driver for the range-reversal program. Reads N and M, builds a tree over
// the sequence 1..N framed by a -INF and a +INF sentinel, applies M
// reversals "l r" and prints the resulting sequence.
// Unless ONLINE_JUDGE is defined, stdin/stdout are redirected to
// splay.in / splay.out.
int main() {
#ifndef ONLINE_JUDGE
	freopen("splay.in", "r", stdin);
	freopen("splay.out", "w", stdout);
#endif
	N = ty();
	M = ty();
	// Sentinels occupy positions 1 and N + 2; value i sits at position i + 1.
	num[1] = -INF;
	num[N + 2] = INF;
	for (int value = 1; value <= N; ++value) {
		num[value + 1] = value;
	}
	build(1, N + 2, 0);
	for (; M-- != 0;) {
		const int l = ty();
		const int r = ty();
		// Shift by one to account for the leading sentinel.
		rever(l + 1, r + 1);
	}
	dfs(root);
	return 0;
}
posted @ 2019-12-28 16:35  newbielyx  阅读(149)  评论(0)  编辑  收藏  举报