返回顶部

洛谷 P3369 【模板】普通平衡树

有旋Treap模板

//pointer version
#include <bits/stdc++.h>

using namespace std;

// One treap node. Equal keys are merged into a single node with a repeat count.
struct Node {
	Node *ch[2];   // ch[0]: left subtree (smaller keys), ch[1]: right subtree (larger keys)
	int val, rank; // val: the stored key; rank: random heap priority drawn from rand()
	int rep_cnt;   // how many copies of val this node represents
	int siz;       // total number of copies stored in this subtree (including this node)

	// explicit: a bare int should never silently convert into a Node.
	explicit Node(int val) : ch{nullptr, nullptr}, val(val), rank(rand()), rep_cnt(1), siz(1) {}

	// Recompute siz from rep_cnt and the (possibly null) children's sizes.
	void upd_siz() {
		siz = rep_cnt;
		if (ch[0] != nullptr) siz += ch[0]->siz;
		if (ch[1] != nullptr) siz += ch[1]->siz;
	}
};

// Rotation direction; the value doubles as the index of the child that gets lifted.
// RT (0): lift the left child (right rotation). LF (1): lift the right child (left rotation).
enum rot_type { RT = 0, LF = 1 };

// Rotate the subtree rooted at `cur` so that its child on side `dir` becomes
// the new subtree root. dir == RT lifts the left child (right rotation),
// dir == LF lifts the right child (left rotation). `cur` is updated in place
// to point at the new root, and both affected nodes get correct sizes.
void _rotate(Node *&cur, rot_type dir) {
	Node *tmp = cur->ch[dir];       // child being lifted

	cur->ch[dir] = tmp->ch[!dir];   // tmp's inner subtree moves across to the old root
	tmp->ch[!dir] = cur;
	// The old root is now tmp's child, so its size must be recomputed first:
	// tmp's size is derived from it.
	cur->upd_siz();
	tmp->upd_siz();
	cur = tmp;
}

// Insert one copy of `val` into the treap rooted at `cur`.
// Duplicate keys bump the node's rep_cnt. A new node is placed at the
// bottom and then rotated upward while its rank is smaller than its parent's.
// This keeps the same min-heap order on `rank` that _del relies on.
void _insert(Node *&cur, int val) {
	if (cur == nullptr) {
		cur = new Node(val);
		return;
	}
	if (cur->val == val) {
		cur->rep_cnt++;
		cur->siz++;
		return;
	}
	if (val < cur->val) {
		_insert(cur->ch[0], val);
		if (cur->ch[0]->rank < cur->rank) _rotate(cur, RT);
	} else {
		_insert(cur->ch[1], val);
		// Same comparison as the left side: both children obey the min-heap on rank.
		if (cur->ch[1]->rank < cur->rank) _rotate(cur, LF);
	}
	cur->upd_siz();
}

// Remove one copy of `val` from the treap rooted at `cur`; does nothing if
// `val` is not present. If the target node has two children, it is rotated
// down by lifting the child with the smaller rank (min-heap), then deletion
// continues in the subtree the target moved into. A node with at most one
// child is unlinked directly and freed.
void _del(Node *&cur, int val) {
	if (cur == nullptr) return;   // val not in the tree
	if (val < cur->val) {
		_del(cur->ch[0], val);
		cur->upd_siz();
		return;
	}
	if (val > cur->val) {
		_del(cur->ch[1], val);
		cur->upd_siz();
		return;
	}
	if (cur->rep_cnt > 1) {
		cur->rep_cnt--;
		cur->siz--;
		return;
	}
	Node *lc = cur->ch[0], *rc = cur->ch[1];
	if (lc != nullptr && rc != nullptr) {
		rot_type dir = lc->rank < rc->rank ? RT : LF;
		_rotate(cur, dir);
		// After the rotation the target sits on the opposite side of the new root.
		_del(cur->ch[!dir], val);
		cur->upd_siz();
		return;
	}
	Node *doomed = cur;
	cur = (lc != nullptr) ? lc : rc;   // nullptr when the target was a leaf
	delete doomed;
}

// Rank of `val`: 1 + the number of stored copies strictly less than `val`.
// Walks down from the root, accumulating the sizes of everything skipped on the left.
int _query_rank(Node *&cur, int val) {
	int before = 0;   // copies already known to be smaller than val
	Node *node = cur;
	while (node != nullptr) {
		int left_siz = (node->ch[0] != nullptr) ? node->ch[0]->siz : 0;
		if (node->val == val) return before + left_siz + 1;
		if (node->val > val) {
			node = node->ch[0];
			continue;
		}
		// Whole subtree is smaller when there is nothing to the right.
		if (node->ch[1] == nullptr) return before + node->siz + 1;
		before += left_siz + node->rep_cnt;
		node = node->ch[1];
	}
	return before + 1;
}

// Value of the `rank`-th smallest stored copy (1-based), or 0 if out of range.
// Iterative descent: skip the left subtree and this node's copies when they
// are too few, and reduce `rank` by what was skipped.
int _query_val(Node *&cur, int rank) {
	for (Node *node = cur; node != nullptr;) {
		const int left_siz = (node->ch[0] != nullptr) ? node->ch[0]->siz : 0;
		if (rank <= left_siz) {
			node = node->ch[0];
		} else if (rank <= left_siz + node->rep_cnt) {
			return node->val;
		} else {
			rank -= left_siz + node->rep_cnt;
			node = node->ch[1];
		}
	}
	return 0;
}

// Retained for compatibility only; _query_prev below no longer uses this scratch global.
int q_pre_tmp;

// Largest stored value strictly less than `val`, or -1 if there is none
// (including when the tree is empty). Each node smaller than `val` is a
// candidate. After taking one we only move into its right subtree, so every
// later candidate is larger. The last candidate seen is therefore the answer.
int _query_prev(Node *cur, int val) {
	int best = -1;
	while (cur != nullptr) {
		if (cur->val < val) {
			best = cur->val;
			cur = cur->ch[1];
		} else {
			cur = cur->ch[0];
		}
	}
	return best;
}

// Retained for compatibility only; _query_sufv below no longer uses this scratch global.
int q_suf_tmp;

// Smallest stored value strictly greater than `val`, or -1 if there is none
// (including when the tree is empty). Mirror image of _query_prev. After
// taking a candidate we only move left, so the last candidate is the smallest.
int _query_sufv(Node *cur, int val) {
	int best = -1;
	while (cur != nullptr) {
		if (cur->val > val) {
			best = cur->val;
			cur = cur->ch[0];
		} else {
			cur = cur->ch[1];
		}
	}
	return best;
}

// Reads n operations "op x" and applies them to the pointer treap:
// 1 insert, 2 delete, 3 rank of x, 4 x-th smallest, 5 predecessor, 6 successor.
int main() {
	// Pure iostream I/O: decouple from C stdio and avoid per-line flushes (no endl).
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	int n = 0;
	std::cin >> n;
	Node *root = nullptr;
	while (n--) {
		int op = 0, x = 0;
		std::cin >> op >> x;
		switch (op) {
			case 1: _insert(root, x); break;
			case 2: _del(root, x); break;
			case 3: std::cout << _query_rank(root, x) << '\n'; break;
			case 4: std::cout << _query_val(root, x) << '\n'; break;
			case 5: std::cout << _query_prev(root, x) << '\n'; break;
			case 6: std::cout << _query_sufv(root, x) << '\n'; break;
			default: break;
		}
	}
	return 0;
}


//array version
#include <bits/stdc++.h>

using namespace std;

// Array-based treap storage. Index 0 is the null sentinel: insert treats
// o == 0 as an empty slot, and pushup reads siz[0] for a missing child.
// Nodes are allocated as 1..tot and slots are never reused after deletion,
// so at most N - 1 new nodes can be created.
const int N = 1e5 + 1;

int n;                      // NOTE(review): unused; main() declares its own local n.
int rnd[N];                 // random heap priority; insert keeps the larger priority above
int val[N], cnt[N], siz[N]; // key, number of copies of the key, total copies in the subtree
int child[N][2];            // child[o][0]: left (smaller keys), child[o][1]: right (larger keys)
int tot;                    // number of node slots allocated so far

// Recompute the subtree size of node o from its own count and its children.
void pushup(int o) {
	const int *kid = child[o];
	siz[o] = cnt[o] + siz[kid[0]] + siz[kid[1]];
}

// Rotate the subtree rooted at o, lifting child[o][d ^ 1] into o's place.
// d == 1 lifts the left child (right rotation); d == 0 lifts the right child.
// The old root ends up as the lifted node's child on side d.
void rotate(int &o, int d) {
	const int e = d ^ 1;              // side of the child being lifted
	const int old_root = o;
	const int lifted = child[old_root][e];
	child[old_root][e] = child[lifted][d];
	child[lifted][d] = old_root;
	pushup(old_root);                 // child first, then its new parent
	pushup(lifted);
	o = lifted;
}

// Insert one copy of v into the subtree rooted at o (0 means empty).
// A new node gets a random priority and is rotated up while its priority
// exceeds its parent's.
void insert(int &o, int v) {
	if (o == 0) {
		o = ++tot;
		val[o] = v;
		rnd[o] = rand();
		cnt[o] = siz[o] = 1;
		child[o][0] = child[o][1] = 0;
		return;
	}
	if (v == val[o]) {
		++cnt[o];
		pushup(o);
		return;
	}
	const int side = (v > val[o]) ? 1 : 0;   // 0: go left, 1: go right
	insert(child[o][side], v);
	if (rnd[child[o][side]] > rnd[o]) rotate(o, side ^ 1);
	pushup(o);
}

//here is different from BST, we keep rotating the target node until it becomes the leaf node and then delete it.
void del(int &o, int v) {
	if (val[o] == v) {
		if (cnt[o] > 1) {
			cnt[o]--;
			pushup(o);
			return;
		}
		if (child[o][0] && child[o][1]) {		//both children exist.
			if (rnd[child[o][0]] > rnd[child[o][1]]) {
				rotate(o, 1);
				del(child[o][1], v);
			}
			else {
				rotate(o, 0);
				del(child[o][0], v);
			}
		}
		else if (child[o][0]) {		//only left child exists
			rotate(o, 1);
			del(child[o][1], v);
		}
		else if (child[o][1]) {		//only right child 
			rotate(o, 0);
			del(child[o][0], v);
		}
		else o = 0;		//leaf node
		pushup(o);
		return;
	}
	if (val[o] > v) del(child[o][0], v);
	else if (val[o] < v) del(child[o][1], v);
	pushup(o);
}

// Rank of v: 1 + the number of stored copies strictly less than v.
// Every path ends in an explicit return; the original let control reach the
// end of a non-void function, which compilers flag.
int queryrank(int o, int v) {
	if (!o) return 1;
	if (val[o] == v) return siz[child[o][0]] + 1;
	if (val[o] > v) return queryrank(child[o][0], v);
	// v is larger: everything in the left subtree plus this node is smaller.
	return queryrank(child[o][1], v) + siz[child[o][0]] + cnt[o];
}

// Value of the k-th smallest stored copy (1-based), or 0 if k is out of range.
// Iterative descent that reduces k by everything skipped on the left.
int querykth(int o, int k) {
	while (o) {
		const int left_siz = siz[child[o][0]];
		if (k <= left_siz) {
			o = child[o][0];
		} else if (k > left_siz + cnt[o]) {
			k -= left_siz + cnt[o];
			o = child[o][1];
		} else {
			return val[o];
		}
	}
	return 0;
}

// Largest stored value strictly less than x, or 0 if there is none.
int find_pre(int o, int x) {
	int best = 0;
	while (o != 0) {
		const bool smaller = val[o] < x;
		if (smaller) best = val[o];         // candidate; look right for a larger one
		o = child[o][smaller ? 1 : 0];
	}
	return best;
}

// Smallest stored value strictly greater than x, or 0 if there is none.
int find_suf(int o, int x) {
	int best = 0;
	while (o != 0) {
		const bool larger = val[o] > x;
		if (larger) best = val[o];          // candidate; look left for a smaller one
		o = child[o][larger ? 0 : 1];
	}
	return best;
}


// Reads n operations "op x" and applies them to the array treap:
// 1 insert, 2 delete, 3 rank of x, 4 x-th smallest, 5 predecessor, 6 successor.
int main () {
	// Pure iostream I/O: decouple from C stdio and avoid per-line flushes (no endl).
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	int n = 0;
	int root = 0;
	std::cin >> n;
	while (n--) {
		int op = 0, x = 0;
		std::cin >> op >> x;
		switch (op) {
			case 1: insert(root, x); break;
			case 2: del(root, x); break;
			case 3: std::cout << queryrank(root, x) << '\n'; break;
			case 4: std::cout << querykth(root, x) << '\n'; break;
			case 5: std::cout << find_pre(root, x) << '\n'; break;
			case 6: std::cout << find_suf(root, x) << '\n'; break;
			default: break;
		}
	}
	return 0;
}
posted @ 2023-05-05 14:21  Rayotaku  阅读(13)  评论(0)  编辑  收藏  举报