Loading

Splay树

P3369 【模板】普通平衡树

image

image

image

image

image

image

image

image

image

image

image

image

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

// Pool capacity, and the sentinel magnitude main() inserts as -INF and INF.
const int N = 100010;
const int INF = 1e9;

// One splay-tree node. Slot 0 of the pool doubles as the null node.
struct Node {
	int ch[2];  // ch[0] = left child, ch[1] = right child (0 = none)
	int p;      // parent index (0 for the root)
	int v;      // key
	int size;   // stored values in this subtree, counted with multiplicity
	int cnt;    // multiplicity of key v
};

Node tr[N];  // node pool; real nodes are allocated from index 1 upward
int root;    // index of the current root (0 while the tree is empty)
int idx;     // index of the most recently allocated node

// Recompute tr[u].size from both children plus u's own multiplicity.
void pushup(int u) {
	const Node &node = tr[u];
	int total = node.cnt;
	for (int side = 0; side < 2; side++) total += tr[node.ch[side]].size;
	tr[u].size = total;
}

// Rotate x one level up, above its parent y, keeping the in-order sequence.
// Afterwards y is x's child, so y's size is recomputed before x's. When x's
// inner child or y's parent z is 0, the corresponding writes land on the null
// slot tr[0]; pushup relies on tr[0].size remaining 0.
void rotate(int x) {
	int y = tr[x].p;
	int z = tr[y].p;
	int k = (tr[y].ch[1] == x);  // 1 if x is y's right child, 0 if left
	
	// x's inner child (the one facing y) is handed over to y.
	tr[y].ch[k] = tr[x].ch[k ^ 1], tr[tr[x].ch[k ^ 1]].p = y;
	// y hangs under x on the side opposite to k.
	tr[x].ch[k ^ 1] = y, tr[y].p = x;
	// x takes y's former place under z.
	tr[z].ch[tr[z].ch[1] == y] = x, tr[x].p = z;
	
	pushup(y);
	pushup(x);
}

// Splay x upward until its parent is k; with k == 0, x becomes the root.
// k is expected to be 0 or an ancestor of x.
void splay(int x, int k) {
	while (tr[x].p != k) {
		int y = tr[x].p;
		int z = tr[y].p;
		
		if (z != k) {
			// Zig-zag (x and y on different sides): rotate x; zig-zig: rotate y first.
			if ((tr[z].ch[1] == y) ^ (tr[y].ch[1] == x)) rotate(x);
			else rotate(y);
		}
		
		rotate(x);
	}
	
	if (!k) root = x;
}

void insert(int x) {
	int u = root, p = 0;
	while (u && tr[u].v != x) {
		p = u;
		u = tr[u].ch[x > tr[u].v];
	}
	if (u) tr[u].cnt++;
	else {
		u = ++idx;
		if (p) tr[p].ch[x > tr[p].v] = u;
		tr[u].p = p;
		tr[u].v = x;
		tr[u].size = 1;
		tr[u].cnt = 1;
	}
	splay(u, 0);
}

// Search for key x and splay the last node visited to the root. The root then
// holds x if it is stored, otherwise x's nearest stored neighbour on one side
// (its predecessor or its successor).
void find(int x) {
	int cur = root;
	for (;;) {
		if (tr[cur].v == x) break;
		const int next = tr[cur].ch[x > tr[cur].v];
		if (next == 0) break;
		cur = next;
	}
	splay(cur, 0);
}

// Node with the largest key strictly smaller than x; it ends up as the root.
// Relies on a smaller key existing (main() inserts a -INF sentinel).
int get_pre(int x) {
	find(x);
	int cur = root;
	if (tr[cur].v >= x) {
		// Root is x or its successor: the answer is the maximum of the left subtree.
		cur = tr[cur].ch[0];
		while (tr[cur].ch[1] != 0) cur = tr[cur].ch[1];
		splay(cur, 0);
	}
	return cur;
}

// Node with the smallest key strictly greater than x; it ends up as the root.
// Relies on a larger key existing (main() inserts an INF sentinel).
int get_suc(int x) {
	find(x);
	int cur = root;
	if (tr[cur].v <= x) {
		// Root is x or its predecessor: the answer is the minimum of the right subtree.
		cur = tr[cur].ch[1];
		while (tr[cur].ch[0] != 0) cur = tr[cur].ch[0];
		splay(cur, 0);
	}
	return cur;
}

// Delete one occurrence of x (x is expected to be stored). After splaying x's
// strict predecessor to the root and its strict successor directly beneath it,
// the node holding x is exactly the successor's left child.
void remove(int x) {
	int lo = get_pre(x);
	int hi = get_suc(x);
	
	splay(lo, 0);
	splay(hi, lo);
	
	int target = tr[hi].ch[0];
	if (tr[target].cnt <= 1) {
		// Last copy (or nothing there): detach hi's left subtree entirely.
		tr[hi].ch[0] = 0;
		splay(hi, 0);
		return;
	}
	tr[target].cnt--;
	splay(target, 0);
}

// Rank of x, i.e. 1 + (number of stored values below x). x is inserted
// temporarily so it becomes the root; its left subtree then contains every
// smaller value plus the -INF sentinel from main(), which supplies the +1.
int get_rank_by_key(int x) {
	insert(x);
	const int smaller = tr[tr[root].ch[0]].size;
	remove(x);
	return smaller;
}

// Value at 1-based position k in sorted order (the -INF sentinel is position 1);
// its node is splayed to the root. k is expected to be within range.
int get_key_by_rank(int k) {
	int cur = root;
	for (;;) {
		const int left = tr[tr[cur].ch[0]].size;
		if (k <= left) {
			cur = tr[cur].ch[0];
			continue;
		}
		k -= left;
		if (k <= tr[cur].cnt) break;
		k -= tr[cur].cnt;
		cur = tr[cur].ch[1];
	}
	splay(cur, 0);
	return tr[cur].v;
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	
	// Sentinels guarantee every query has a predecessor and a successor.
	insert(-INF);
	insert(INF);
	
	int T;
	cin >> T;
	for (int opt, x; T--; ) {
		cin >> opt >> x;
		switch (opt) {
			case 1: insert(x); break;
			case 2: remove(x); break;
			case 3: cout << get_rank_by_key(x) << '\n'; break;
			// +1 skips the -INF sentinel, which occupies sorted position 1.
			case 4: cout << get_key_by_rank(x + 1) << '\n'; break;
			case 5: cout << tr[get_pre(x)].v << '\n'; break;
			default: cout << tr[get_suc(x)].v << '\n'; break;
		}
	}
	
	return 0;
}

P6136 【模板】普通平衡树(数据加强版)

注意数组要开到 \(1.1 \times 10^6\) 以上:初始的 \(n \le 10^5\) 个数加上 \(m \le 10^6\) 次操作中,每次插入(包括查询排名时的临时插入)都可能新建一个结点,而删除后结点不会被回收。

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

// Pool capacity, and the sentinel magnitude main() inserts as -INF and INF.
const int N = 1100010;
const int INF = (1 << 30) + 1;

// One splay-tree node. Slot 0 of the pool doubles as the null node.
struct Node {
	int ch[2];  // ch[0] = left child, ch[1] = right child (0 = none)
	int p;      // parent index (0 for the root)
	int v;      // key
	int size;   // stored values in this subtree, counted with multiplicity
	int cnt;    // multiplicity of key v
};

Node tr[N];  // node pool; real nodes are allocated from index 1 upward
int root;    // index of the current root (0 while the tree is empty)
int idx;     // index of the most recently allocated node

// Recompute tr[u].size from both children plus u's own multiplicity.
void pushup(int u) {
	const int *kid = tr[u].ch;
	tr[u].size = tr[u].cnt + tr[kid[0]].size + tr[kid[1]].size;
}

// Rotate x one level up, above its parent y, keeping the in-order sequence.
// Afterwards y is x's child, so y's size is recomputed before x's. When x's
// inner child or y's parent z is 0, the corresponding writes land on the null
// slot tr[0]; pushup relies on tr[0].size remaining 0.
void rotate(int x) {
	int y = tr[x].p;
	int z = tr[y].p;
	int k = (tr[y].ch[1] == x);  // 1 if x is y's right child, 0 if left
	
	// x's inner child (the one facing y) is handed over to y.
	tr[y].ch[k] = tr[x].ch[k ^ 1], tr[tr[x].ch[k ^ 1]].p = y;
	// y hangs under x on the side opposite to k.
	tr[x].ch[k ^ 1] = y, tr[y].p = x;
	// x takes y's former place under z.
	tr[z].ch[tr[z].ch[1] == y] = x, tr[x].p = z;
	
	pushup(y);
	pushup(x);
}

// Splay x upward until its parent is k; with k == 0, x becomes the root.
// k is expected to be 0 or an ancestor of x.
void splay(int x, int k) {
	while (tr[x].p != k) {
		int y = tr[x].p;
		int z = tr[y].p;
		
		if (z != k) {
			// Zig-zag (x and y on different sides): rotate x; zig-zig: rotate y first.
			if ((tr[z].ch[1] == y) ^ (tr[y].ch[1] == x)) rotate(x);
			else rotate(y);
		}
		
		rotate(x);
	}
	
	if (!k) root = x;
}

void insert(int x) {
	int u = root, p = 0;
	while (u && tr[u].v != x) {
		p = u;
		u = tr[u].ch[x > tr[u].v];
	}
	if (u) tr[u].cnt++;
	else {
		u = ++idx;
		if (p) tr[p].ch[x > tr[p].v] = u;
		tr[u].p = p;
		tr[u].v = x;
		tr[u].size = 1;
		tr[u].cnt = 1;
	}
	splay(u, 0);
}

// Search for key x and splay the last node visited to the root. The root then
// holds x if it is stored, otherwise x's nearest stored neighbour on one side
// (its predecessor or its successor).
void find(int x) {
	int cur = root;
	for (;;) {
		if (tr[cur].v == x) break;
		const int next = tr[cur].ch[x > tr[cur].v];
		if (next == 0) break;
		cur = next;
	}
	splay(cur, 0);
}

// Node with the largest key strictly smaller than x; it ends up as the root.
// Relies on a smaller key existing (main() inserts a -INF sentinel).
int get_pre(int x) {
	find(x);
	int cur = root;
	if (tr[cur].v >= x) {
		// Root is x or its successor: the answer is the maximum of the left subtree.
		cur = tr[cur].ch[0];
		while (tr[cur].ch[1] != 0) cur = tr[cur].ch[1];
		splay(cur, 0);
	}
	return cur;
}

// Node with the smallest key strictly greater than x; it ends up as the root.
// Relies on a larger key existing (main() inserts an INF sentinel).
int get_suc(int x) {
	find(x);
	int cur = root;
	if (tr[cur].v <= x) {
		// Root is x or its predecessor: the answer is the minimum of the right subtree.
		cur = tr[cur].ch[1];
		while (tr[cur].ch[0] != 0) cur = tr[cur].ch[0];
		splay(cur, 0);
	}
	return cur;
}

// Delete one occurrence of x (x is expected to be stored). After splaying x's
// strict predecessor to the root and its strict successor directly beneath it,
// the node holding x is exactly the successor's left child.
void remove(int x) {
	int lo = get_pre(x);
	int hi = get_suc(x);
	
	splay(lo, 0);
	splay(hi, lo);
	
	int target = tr[hi].ch[0];
	if (tr[target].cnt <= 1) {
		// Last copy (or nothing there): detach hi's left subtree entirely.
		tr[hi].ch[0] = 0;
		splay(hi, 0);
		return;
	}
	tr[target].cnt--;
	splay(target, 0);
}

// Rank of x, i.e. 1 + (number of stored values below x). x is inserted
// temporarily so it becomes the root; its left subtree then contains every
// smaller value plus the -INF sentinel from main(), which supplies the +1.
int get_rank_by_key(int x) {
	insert(x);
	const int smaller = tr[tr[root].ch[0]].size;
	remove(x);
	return smaller;
}

// Value at 1-based position k in sorted order (the -INF sentinel is position 1);
// its node is splayed to the root. k is expected to be within range.
int get_key_by_rank(int k) {
	int cur = root;
	for (;;) {
		const int left = tr[tr[cur].ch[0]].size;
		if (k <= left) {
			cur = tr[cur].ch[0];
			continue;
		}
		k -= left;
		if (k <= tr[cur].cnt) break;
		k -= tr[cur].cnt;
		cur = tr[cur].ch[1];
	}
	splay(cur, 0);
	return tr[cur].v;
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	
	// Sentinels guarantee every query has a predecessor and a successor.
	insert(-INF);
	insert(INF);
	
	int n, T;
	cin >> n >> T;
	for (int i = 0; i < n; i++) {
		int x;
		cin >> x;
		insert(x);
	}
	
	// Queries are forced online: each x is XOR-ed with the previous answer
	// ("last"), and the program prints the XOR of all answers.
	int res = 0, last = 0;
	while (T--) {
		int opt, x;
		cin >> opt >> x;
		x ^= last;
		if (opt == 1) { insert(x); continue; }
		if (opt == 2) { remove(x); continue; }
		
		if (opt == 3) last = get_rank_by_key(x);
		else if (opt == 4) last = get_key_by_rank(x + 1);  // +1 skips the -INF sentinel
		else if (opt == 5) last = tr[get_pre(x)].v;
		else last = tr[get_suc(x)].v;
		res ^= last;
	}
	
	cout << res << '\n';
	
	return 0;
}

AcWing 2437 Splay / P3391 【模板】文艺平衡树

image

image

image

以下是代码:

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

// Node-pool capacity.
const int N = 100010;

// One node of the sequence splay tree; slot 0 of the pool acts as the null node.
struct Node {
	int ch[2];  // left / right child (0 = none)
	int size;   // number of nodes in this subtree
	int p;      // parent (0 for the root)
	int v;      // stored element; main() also stores sentinels 0 and n + 1
	int flag;   // pending "reverse this subtree" tag (0 or 1)
};

Node tr[N];
int root, idx;  // current root, most recently allocated slot

// Subtree size = left size + right size + the node itself.
void pushup(int u) {
	const int *kid = tr[u].ch;
	tr[u].size = 1 + tr[kid[0]].size + tr[kid[1]].size;
}

// Resolve u's pending reversal: swap its children and hand the tag down to
// both of them (toggling, so two reversals cancel out).
void pushdown(int u) {
	if (tr[u].flag == 0) return;
	int *kid = tr[u].ch;
	swap(kid[0], kid[1]);
	for (int side = 0; side < 2; side++) tr[kid[side]].flag ^= 1;
	tr[u].flag ^= 1;
}

// Rotate x one level above its parent y, keeping the in-order sequence; y
// becomes x's child, so its size is recomputed first. No reversal tags are
// pushed here. NOTE(review): this assumes every node on the rotated path was
// already pushed down (get_k does that on its way down) — confirm per call site.
void rotate(int x) {
	int y = tr[x].p;
	int z = tr[y].p;
	int k = tr[y].ch[1] == x;  // 1 if x is y's right child, 0 if left
	
	// x takes y's former place under z (may write into the null slot tr[0]).
	tr[z].ch[tr[z].ch[1] == y] = x, tr[x].p = z;
	// x's inner child (the one facing y) moves over to y.
	tr[y].ch[k] = tr[x].ch[k ^ 1], tr[tr[x].ch[k ^ 1]].p = y;
	// y hangs under x on the side opposite to k.
	tr[x].ch[k ^ 1] = y, tr[y].p = x;
	
	pushup(y);
	pushup(x);
}

// Splay x upward until its parent is k; with k == 0, x becomes the root.
// k is expected to be 0 or an ancestor of x.
void splay(int x, int k) {
	while (tr[x].p != k) {
		int y = tr[x].p;
		int z = tr[y].p;
		if (z != k) {
			// Zig-zag (x and y on different sides): rotate x; zig-zig: rotate y first.
			if ((tr[z].ch[1] == y) ^ (tr[y].ch[1] == x)) rotate(x);
			else rotate(y);
		}
		rotate(x);
	}
	if (!k) root = x;
}

int n, m;

void insert(int x) {
	int u = root, p = 0;
	while (u) p = u, u = tr[u].ch[x > tr[u].v];
	u = ++idx;
	if (p) tr[p].ch[x > tr[p].v] = u;
	tr[u].v = x;
	tr[u].size = 1;
	tr[u].p = p;
	tr[u].flag = 0;
	splay(u, 0);
}

// Return the node at 1-based position k of the current sequence (sentinels
// included). Pending reversal tags are pushed down along the search path, and
// the found node is splayed to the root. k is expected to be in range.
int get_k(int k) {
	int p = root;
	for (;;) {
		pushdown(p);
		const int left = tr[tr[p].ch[0]].size;
		if (k <= left) {
			p = tr[p].ch[0];
		} else if (k == left + 1) {
			break;
		} else {
			k -= left + 1;
			p = tr[p].ch[1];
		}
	}
	splay(p, 0);
	return p;
}

void output(int u) {
	pushdown(u);
	if (tr[u].ch[0]) output(tr[u].ch[0]);
	if (tr[u].v >= 1 && tr[u].v <= n) cout << tr[u].v << ' ';
	if (tr[u].ch[1]) output(tr[u].ch[1]);
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	
	cin >> n >> m;
	// Keys 0 and n + 1 frame the real elements 1..n, so every range [x, y] has
	// a node right before it and a node right after it.
	for (int key = 0; key <= n + 1; key++) insert(key);
	
	while (m--) {
		int x, y;
		cin >> x >> y;
		// The leading sentinel shifts element i to position i + 1: position x is
		// the element just before the range, position y + 2 the one just after.
		const int l = get_k(x);
		const int r = get_k(y + 2);
		splay(l, 0);
		splay(r, l);
		// r's left subtree is now exactly the range [x, y]; tag it for reversal.
		tr[tr[r].ch[0]].flag ^= 1;
	}
	output(root);
	return 0;
}

CF675D Tree Construction

实际上,\(x\) 的父结点就是插入 \(x\) 时,树中 \(x\) 的前驱和后继里编号(插入顺序)较大的那一个。原因是:前驱和后继在二叉搜索树中必然一个是另一个的祖先,较晚插入的那个位置更深,\(x\) 就挂在它的下面。

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

// Node-pool capacity and sentinel magnitude. The sentinel must be strictly
// larger than every key. In CF675D a_i may equal 1e9 (per the statement), and
// with INF == 1e9 such a key would be merged into the +INF sentinel by insert()
// (cnt++), keep the sentinel's kth of -INF, and produce a wrong parent.
// 2e9 is still representable as int (and so is -2e9).
const int N = 100010, INF = 2000000000;

// One splay-tree node; slot 0 of the pool doubles as the null node.
struct Node {
	int ch[2];  // left / right child (0 = none)
	int size;   // stored values in this subtree, counted with multiplicity
	int p, v;   // parent index (0 for the root) and key
	int cnt;    // multiplicity of key v
	int kth;    // insertion index of the key (-INF for the sentinels)
}tr[N];

int root, idx;  // current root, most recently allocated slot

// Recompute tr[u].size from its children and its own multiplicity.
void pushup(int u) {
	tr[u].size = tr[tr[u].ch[0]].size + tr[tr[u].ch[1]].size + tr[u].cnt;
}

// Rotate x one level above its parent y, keeping the in-order sequence; y
// becomes x's child, so its size is recomputed first. When y's parent z or x's
// inner child is 0, the writes land on the null slot tr[0]; pushup relies on
// tr[0].size remaining 0.
void rotate(int x) {
	int y = tr[x].p;
	int z = tr[y].p;
	int k = tr[y].ch[1] == x;  // 1 if x is y's right child, 0 if left
	
	// x takes y's former place under z.
	tr[z].ch[tr[z].ch[1] == y] = x, tr[x].p = z;
	// x's inner child (the one facing y) moves over to y.
	tr[y].ch[k] = tr[x].ch[k ^ 1], tr[tr[x].ch[k ^ 1]].p = y;
	// y hangs under x on the side opposite to k.
	tr[x].ch[k ^ 1] = y, tr[y].p = x;
	
	pushup(y);
	pushup(x);
}

// Splay x upward until its parent is k; with k == 0, x becomes the root.
// k is expected to be 0 or an ancestor of x.
void splay(int x, int k) {
	while (tr[x].p != k) {
		int y = tr[x].p;
		int z = tr[y].p;
		
		if (z != k) {
			// Zig-zag (x and y on different sides): rotate x; zig-zig: rotate y first.
			if ((tr[z].ch[1] == y) ^ (tr[y].ch[1] == x)) rotate(x);
			else rotate(y);
		}
		rotate(x);
	}
	if (!k) root = x;
}

// Search for key x and splay the last node visited to the root. The root then
// holds x if it is stored, otherwise x's nearest stored neighbour on one side
// (its predecessor or its successor).
void find(int x) {
	int cur = root;
	for (;;) {
		if (tr[cur].v == x) break;
		const int next = tr[cur].ch[x > tr[cur].v];
		if (next == 0) break;
		cur = next;
	}
	splay(cur, 0);
}

// Node with the largest key strictly smaller than x; it ends up as the root.
// Relies on a smaller key existing (main() inserts a -INF sentinel).
int get_pre(int x) {
	find(x);
	int cur = root;
	if (tr[cur].v >= x) {
		// Root is x or its successor: the answer is the maximum of the left subtree.
		cur = tr[cur].ch[0];
		while (tr[cur].ch[1] != 0) cur = tr[cur].ch[1];
		splay(cur, 0);
	}
	return cur;
}

// Node with the smallest key strictly greater than x; it ends up as the root.
// Relies on a larger key existing (main() inserts an INF sentinel).
int get_suc(int x) {
	find(x);
	int cur = root;
	if (tr[cur].v <= x) {
		// Root is x or its predecessor: the answer is the minimum of the right subtree.
		cur = tr[cur].ch[1];
		while (tr[cur].ch[0] != 0) cur = tr[cur].ch[0];
		splay(cur, 0);
	}
	return cur;
}

void insert(int x, int k) {
	int p = 0, u = root;
	while (u && tr[u].v != x)
		p = u, u = tr[u].ch[x > tr[u].v];
	if (u) tr[u].cnt++;
	else {
		u = ++idx;
		if (p) tr[p].ch[x > tr[p].v] = u;
		tr[u].p = p;
		tr[u].size = tr[u].cnt = 1;
		tr[u].v = x;
		tr[u].kth = k;
	}
	splay(u, 0);
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	
	// Sentinels carry insertion index -INF, below every real index, so they
	// never win the "inserted later" comparison below.
	insert(-INF, -INF);
	insert(INF, -INF);
	
	int T, first;
	cin >> T >> first;
	insert(first, 1);
	
	for (int i = 2; i <= T; i++) {
		int x;
		cin >> x;
		// The parent of x is whichever of its current predecessor / successor
		// was inserted later.
		const int pre = get_pre(x);
		const int suc = get_suc(x);
		const int parent = (tr[suc].kth < tr[pre].kth) ? pre : suc;
		cout << tr[parent].v << ' ';
		insert(x, i);
	}
	return 0;
}

P2286 [HNOI2004]宠物收养场

  • 任意时刻收养所里只可能有一种(宠物或领养者)在等待,所以一棵平衡树就够了。
  • 只有宠物的时候,拿一个平衡树维护宠物的信息。
    • 那么当一个人来的时候,可以找他的特点值 \(x\) 的前驱和后继(都允许等于 \(x\)),看它们谁与 \(x\) 的差距更小就取谁。(如果差距相同,取特点值较小的那个,即前驱)
  • 只有人的时候,拿一个平衡树维护人的信息。
    • 同理,当一只宠物来的时候可以用它的特点值 \(x\) 找出最适合它的主人。

代码:

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

// P2286: the tree holds whichever side (pets or adopters) is currently waiting.
// Characteristic values may be as large as 2^31 - 1 (per the problem statement;
// confirm). An int sentinel of 1e9 is then not an upper bound: a value above it
// can have the sentinel returned as its predecessor, and the sentinel can even
// be deleted. Subtracting or accumulating such values also overflows int. Keys,
// sentinels, distances and the running total are therefore kept in long long.
typedef long long LL;

const int N = 80010, mod = 1000000;
const LL INF = 1000000000000000000LL;  // sentinel magnitude, beyond every key

// One splay-tree node; slot 0 of the pool doubles as the null node.
struct Node {
	int ch[2];  // left / right child (0 = none)
	int size;   // stored values in this subtree, counted with multiplicity
	int p;      // parent (0 for the root)
	LL v;       // key
	int cnt;    // multiplicity of v
}tr[N];

int root, idx;  // current root, most recently allocated slot

// Recompute tr[u].size from its children and its own multiplicity.
void pushup(int u) {
	tr[u].size = tr[tr[u].ch[0]].size + tr[tr[u].ch[1]].size + tr[u].cnt;
}

// Rotate x one level above its parent y; y ends up as x's child, so its size is
// refreshed first. Writes may land on the null slot tr[0], whose size stays 0.
void rotate(int x) {
	int y = tr[x].p;
	int z = tr[y].p;
	int k = tr[y].ch[1] == x;
	
	tr[z].ch[tr[z].ch[1] == y] = x, tr[x].p = z;
	tr[y].ch[k] = tr[x].ch[k ^ 1], tr[tr[x].ch[k ^ 1]].p = y;
	tr[x].ch[k ^ 1] = y, tr[y].p = x;
	
	pushup(y);
	pushup(x);
}

// Splay x upward until its parent is k; with k == 0, x becomes the root.
void splay(int x, int k) {
	while (tr[x].p != k) {
		int y = tr[x].p;
		int z = tr[y].p;
		
		if (z != k) {
			if ((tr[z].ch[1] == y) ^ (tr[y].ch[1] == x)) rotate(x);
			else rotate(y);
		}
		rotate(x);
	}
	if (!k) root = x;
}

// Splay the last node on the search path for x to the root: x itself if it is
// stored, otherwise its predecessor or its successor.
void find(LL x) {
	int u = root;
	while (tr[u].ch[x > tr[u].v] && tr[u].v != x) {
		u = tr[u].ch[x > tr[u].v];
	}
	splay(u, 0);
}

// Node with the largest key strictly less than x (ends up as the root).
int get_pre(LL x) {
	find(x);
	if (tr[root].v < x) return root;
	int u = tr[root].ch[0];
	while (tr[u].ch[1]) u = tr[u].ch[1];
	splay(u, 0);
	return u;
}

// Node with the smallest key strictly greater than x (ends up as the root).
int get_suc(LL x) {
	find(x);
	if (tr[root].v > x) return root;
	int u = tr[root].ch[1];
	while (tr[u].ch[0]) u = tr[u].ch[0];
	splay(u, 0);
	return u;
}

// Node with the largest key <= x. (Named without a leading underscore: such
// global names are reserved for the implementation.)
int get_pre_or_equal(LL x) {
	find(x);
	if (tr[root].v <= x) return root;
	int u = tr[root].ch[0];
	while (tr[u].ch[1]) u = tr[u].ch[1];
	splay(u, 0);
	return u;
}

// Node with the smallest key >= x.
int get_suc_or_equal(LL x) {
	find(x);
	if (tr[root].v >= x) return root;
	int u = tr[root].ch[1];
	while (tr[u].ch[0]) u = tr[u].ch[0];
	splay(u, 0);
	return u;
}

// Insert one copy of x and splay its node to the root.
void insert(LL x) {
	int p = 0, u = root;
	while (u && tr[u].v != x)
		p = u, u = tr[u].ch[x > tr[u].v];
	if (u) tr[u].cnt++;
	else {
		u = ++idx;
		if (p) tr[p].ch[x > tr[p].v] = u;
		tr[u].p = p;
		tr[u].size = tr[u].cnt = 1;
		tr[u].v = x;
	}
	splay(u, 0);
}

// Remove one copy of x, which must be stored. With its strict predecessor at
// the root and its strict successor right below, x is the successor's left child.
void del(LL x) {
	int pre = get_pre(x);
	int suc = get_suc(x);
	splay(pre, 0);
	splay(suc, pre);
	int del_v = tr[suc].ch[0];
	if (tr[del_v].cnt > 1) tr[del_v].cnt--, splay(del_v, 0);
	else tr[suc].ch[0] = 0, splay(suc, 0);
}

// 1 + number of stored real values smaller than x (the -INF sentinel supplies
// the +1). find() may stop on x's predecessor, whose copies are smaller too.
int get_rank_by_key(LL x) {
	find(x);
	int res = tr[tr[root].ch[0]].size;
	if (tr[root].v < x) res += tr[root].cnt;
	return res;
}

// k-th smallest stored key (1-based, the -INF sentinel being position 1).
LL get_key_by_rank(int k) {
	int u = root;
	while (true) {
		if (k <= tr[tr[u].ch[0]].size) u = tr[u].ch[0];
		else if (k <= tr[tr[u].ch[0]].size + tr[u].cnt) break;
		else k -= tr[tr[u].ch[0]].size + tr[u].cnt, u = tr[u].ch[1];
	}
	splay(u, 0);
	return tr[u].v;
}

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	
	// Sentinels guarantee both neighbour queries always find a node.
	insert(-INF);
	insert(INF);
	
	int n;
	cin >> n;
	int cnt[2] = {0, 0};  // cnt[t]: entries of type t currently waiting
	LL sum = 0;
	for (int i = 1; i <= n; i++) {
		int opt;
		LL x;
		cin >> opt >> x;
		if (cnt[opt ^ 1] <= 0) {
			// Nobody of the other type is waiting: this arrival waits as well.
			cnt[opt]++;
			insert(x);
			continue;
		}
		cnt[opt ^ 1]--;
		// Closest waiting values on each side of x (equality allowed).
		LL a = tr[get_pre_or_equal(x)].v;
		LL b = tr[get_suc_or_equal(x)].v;
		LL c = x - a;  // a <= x
		LL d = b - x;  // b >= x
		// On a tie the smaller value a is chosen.
		if (c <= d) del(a), sum = (sum + c) % mod;
		else del(b), sum = (sum + d) % mod;
	}
	
	cout << sum << '\n';
	return 0;
}

P2234 [HNOI2002]营业额统计

对每一天的营业额 \(x\),在此前出现过的营业额中找到不超过 \(x\) 的最大值(前驱)和不小于 \(x\) 的最小值(后继),分别与 \(x\) 作差取绝对值,累加较小者即可。

注意第 \(1\) 天要特殊处理:此前没有营业额可以比较,它的最小波动值就是 \(a_1\) 本身。

#include <iostream>
#include <cstring>
#include <algorithm>

using namespace std;

// Node-pool capacity, and the sentinel magnitude main() inserts as -INF and INF.
const int N = 40000;
const int INF = 1000000000;

// One splay-tree node; slot 0 of the pool doubles as the null node.
struct Node {
	int ch[2];  // left / right child (0 = none)
	int size;   // stored values in this subtree, counted with multiplicity
	int p;      // parent (0 for the root)
	int v;      // key
	int cnt;    // multiplicity of v
};

Node tr[N];
int root, idx;  // current root, most recently allocated slot

// Recompute tr[u].size from its children and its own multiplicity.
void pushup(int u) {
	int total = tr[u].cnt;
	for (int side = 0; side < 2; side++) total += tr[tr[u].ch[side]].size;
	tr[u].size = total;
}

// Rotate x one level above its parent y, keeping the in-order sequence; y
// becomes x's child, so its size is recomputed first. When y's parent z or x's
// inner child is 0, the writes land on the null slot tr[0]; pushup relies on
// tr[0].size remaining 0.
void rotate(int x) {
	int y = tr[x].p;
	int z = tr[y].p;
	int k = tr[y].ch[1] == x;  // 1 if x is y's right child, 0 if left
	
	// x takes y's former place under z.
	tr[z].ch[tr[z].ch[1] == y] = x, tr[x].p = z;
	// x's inner child (the one facing y) moves over to y.
	tr[y].ch[k] = tr[x].ch[k ^ 1], tr[tr[x].ch[k ^ 1]].p = y;
	// y hangs under x on the side opposite to k.
	tr[x].ch[k ^ 1] = y, tr[y].p = x;
	
	pushup(y);
	pushup(x);
}

// Splay x upward until its parent is k; with k == 0, x becomes the root.
// k is expected to be 0 or an ancestor of x.
void splay(int x, int k) {
	while (tr[x].p != k) {
		int y = tr[x].p;
		int z = tr[y].p;
		
		if (z != k) {
			// Zig-zag (x and y on different sides): rotate x; zig-zig: rotate y first.
			if ((tr[z].ch[1] == y) ^ (tr[y].ch[1] == x)) rotate(x);
			else rotate(y);
		}
		rotate(x);
	}
	if (!k) root = x;
}

// Search for key x and splay the last node visited to the root. The root then
// holds x if it is stored, otherwise x's nearest stored neighbour on one side
// (its predecessor or its successor).
void find(int x) {
	int cur = root;
	for (;;) {
		if (tr[cur].v == x) break;
		const int next = tr[cur].ch[x > tr[cur].v];
		if (next == 0) break;
		cur = next;
	}
	splay(cur, 0);
}

// Node with the largest key <= x (x itself if stored); it ends up as the root.
int get_pre(int x) {
	find(x);
	int cur = root;
	if (tr[cur].v > x) {
		// Root is x's successor: the answer is the maximum of the left subtree.
		cur = tr[cur].ch[0];
		while (tr[cur].ch[1] != 0) cur = tr[cur].ch[1];
		splay(cur, 0);
	}
	return cur;
}

// Node with the smallest key >= x (x itself if stored); it ends up as the root.
int get_suc(int x) {
	find(x);
	int cur = root;
	if (tr[cur].v < x) {
		// Root is x's predecessor: the answer is the minimum of the right subtree.
		cur = tr[cur].ch[1];
		while (tr[cur].ch[0] != 0) cur = tr[cur].ch[0];
		splay(cur, 0);
	}
	return cur;
}

void insert(int x) {
	int p = 0, u = root;
	while (u && tr[u].v != x)
		p = u, u = tr[u].ch[x > tr[u].v];
	if (u) tr[u].cnt++;
	else {
		u = ++idx;
		if (p) tr[p].ch[x > tr[p].v] = u;
		tr[u].p = p;
		tr[u].size = tr[u].cnt = 1;
		tr[u].v = x;
	}
	splay(u, 0);
}

// Running total of the daily minimum fluctuations.
int sum;

int main() {
	ios::sync_with_stdio(false);
	cin.tie(nullptr);
	
	// Sentinels guarantee both neighbour queries always find a node.
	insert(-INF);
	insert(INF);
	
	int n, first;
	cin >> n >> first;
	// Day 1 has no earlier turnover, so its own value is counted as is.
	sum = first;
	insert(first);
	for (int day = 2; day <= n; day++) {
		int x;
		cin >> x;
		// Closest earlier values on each side of x (either may equal x).
		const int below = tr[get_pre(x)].v;
		const int above = tr[get_suc(x)].v;
		sum += min(abs(below - x), abs(above - x));
		insert(x);
	}
	cout << sum << '\n';
	return 0;
}
posted @ 2023-03-16 08:45  SunnyYuan  阅读(66)  评论(1编辑  收藏  举报