AcWing 253 普通平衡树 (treap)

代码提交记录（AcWing 253）：https://www.acwing.com/problem/content/submission/code_detail/2916478/

普通旋转 \(treap\) 的核心操作是 \(zig\)（右旋）与 \(zag\)（左旋），
并利用节点随机优先级上的（大根）堆性质来维持树的平衡。

旋转 \(treap\)

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cmath>
#include<stack>
#include<queue>
#include<ctime>
using namespace std;
typedef long long ll;

// Capacity of the static node pool t[]; presumably sized for ~1e5 insertions
// plus the two sentinel nodes, with some slack (TODO confirm against limits).
const int maxn = 100010;
// Sentinel magnitude (INT_MAX); the -INF / +INF sentinel nodes bracket every key.
const int INF = 0x7fffffff;

int n;        // number of operations read from input
int tot = 0;  // index of the most recently allocated node; node 0 is the null node
int rt;       // index of the current treap root

// One treap node. Nodes live in the static pool t[]; links are pool indices and
// index 0 stands for "no child" (its fields stay zero).
struct Treap{
	int l;    // index of the left child
	int r;    // index of the right child
	int val;  // key stored in this node
	int dat;  // random priority; insert() keeps parent.dat >= child.dat (max-heap)
	int cnt;  // multiplicity of val
	int sz;   // total multiplicity in this subtree (cnt plus both children's sz)
} t[maxn];

int New(int val){
	t[++tot].val = val;
	t[tot].dat = rand();
	t[tot].cnt = t[tot].sz = 1;
	return tot;
}

void pushup(int p){
	t[p].sz = t[t[p].l].sz + t[t[p].r].sz + t[p].cnt;
}

// Initialises the treap with two sentinels: node 1 (-INF) becomes the root and
// node 2 (+INF) its right child. Indices 1 and 2 are hard-coded because
// GetPre/GetNext use them as fallback answers. The two priorities are not
// compared here, so the heap order between them is not enforced by this function.
void build(){
	New(-INF);
	New(INF);
	t[1].r = 2;
	rt = 1;
	pushup(rt);
}

// Right rotation: p's left child becomes the subtree root and the old root
// becomes its right child. p (a reference into the parent link or rt) is
// updated to the new root, and both touched nodes' sizes are refreshed,
// lower node first.
void zig(int &p){
	int oldRoot = p;
	int leftChild = t[oldRoot].l;
	t[oldRoot].l = t[leftChild].r;
	t[leftChild].r = oldRoot;
	p = leftChild;
	pushup(oldRoot);
	pushup(leftChild);
}

// Left rotation: p's right child becomes the subtree root and the old root
// becomes its left child. p is updated to the new root, and sizes are refreshed
// bottom-up.
void zag(int &p){
	int oldRoot = p;
	int rightChild = t[oldRoot].r;
	t[oldRoot].r = t[rightChild].l;
	t[rightChild].l = oldRoot;
	p = rightChild;
	pushup(oldRoot);
	pushup(rightChild);
}

// Inserts one occurrence of val into the subtree rooted at p.
// An existing key only has its multiplicity bumped. Otherwise a new leaf is
// created, and on the way back up a child whose priority exceeds its parent's
// is rotated above it (max-heap on dat). p may be replaced by a rotation.
void insert(int &p, int val){
	if(!p){
		p = New(val);
		return;
	}
	if(val == t[p].val){
		++t[p].cnt;
		pushup(p);
		return;
	}
	bool goLeft = val < t[p].val;
	if(goLeft){
		insert(t[p].l, val);
		if(t[t[p].l].dat > t[p].dat) zig(p);
	} else {
		insert(t[p].r, val);
		if(t[t[p].r].dat > t[p].dat) zag(p);
	}
	pushup(p);
}

// Deletes one occurrence of val from the subtree rooted at p; no-op if val is absent.
// If the node holds several copies, only its multiplicity is decremented.
// Otherwise the node is rotated downward until it becomes a leaf and is then
// unlinked. Each rotation lifts the child with the larger priority, so heap
// order is preserved. p may be replaced (by a rotation, or by 0 when the leaf
// is unlinked).
void remove(int &p, int val){
	if(p == 0) return;

	if(t[p].val != val){
		if(val < t[p].val) remove(t[p].l, val);
		else remove(t[p].r, val);
		pushup(p);
		return;
	}

	if(t[p].cnt > 1){
		t[p].cnt--;
		pushup(p);
		return;
	}

	if(t[p].l == 0 && t[p].r == 0){
		p = 0;  // leaf: detach it from its parent
		return;
	}

	// Lift the higher-priority child. A missing left child reads as t[0].dat == 0,
	// which never beats a real priority, so we zag in that case. After the
	// rotation the target node sits in the subtree opposite the lifted child.
	if(t[p].r == 0 || t[t[p].l].dat > t[t[p].r].dat){
		zig(p);
		remove(t[p].r, val);
	} else {
		zag(p);
		remove(t[p].l, val);
	}
	pushup(p);
}

// Returns the 1-based rank of val within the subtree rooted at p: one plus the
// number of stored values (counting multiplicity) strictly smaller than val.
// For a present value this is the rank of its first copy.
// The empty-subtree case returns 1 rather than 0: no further smaller values
// exist below, so the rank inside an empty subtree is 1. With 0, a value absent
// from the tree got "#smaller" instead of "#smaller + 1", one less than its
// rank. Results for values that are present are unchanged.
// Note: main() subtracts 1 to discount the -INF sentinel.
int GetRankByVal(int p, int val){
	if(!p) return 1;
	if(t[p].val == val){
		return t[t[p].l].sz + 1;
	}
	if(val < t[p].val) return GetRankByVal(t[p].l, val);
	// Everything in the left subtree plus this node's copies are smaller than val.
	return GetRankByVal(t[p].r, val) + t[t[p].l].sz + t[p].cnt;
}

// Returns the key whose 1-based rank (counting multiplicity) within the subtree
// rooted at p is `rank`. Returns INF when no such position exists (the walk
// falls off the tree).
int GetValByRank(int p, int rank){
	while(p){
		int leftSize = t[t[p].l].sz;
		if(rank <= leftSize){
			p = t[p].l;
			continue;
		}
		int throughHere = leftSize + t[p].cnt;
		if(rank <= throughHere) return t[p].val;
		rank -= throughHere;
		p = t[p].r;
	}
	return INF;
}

int GetPre(int val){
	int ans = 1; // -INF
	int p = rt;
	while(p){
		if(t[p].val == val){
			if(t[p].l){
				p = t[p].l;
				while(t[p].r) p = t[p].r;
				ans = p;
			}
			break;
		}
		if(t[p].val < val && t[p].val > t[ans].val) ans = p;
		p = val < t[p].val ? t[p].l : t[p].r;
	}
	return t[ans].val;
}

int GetNext(int val){
	int ans = 2; // INF
	int p = rt;
	while(p){
		if(t[p].val == val){
			if(t[p].r){
				p = t[p].r;
				while(t[p].l) p = t[p].l;
				ans = p;
			}
			break;
		}
		if(t[p].val > val && t[p].val < t[ans].val) ans = p;
		p = val < t[p].val ? t[p].l : t[p].r;
	}
	return t[ans].val;
}

// Reads the next decimal integer from stdin.
// Non-digit characters before the first digit are skipped; a '-' among them
// makes the result negative. If EOF is hit before any digit, returns 0 instead
// of looping forever.
// ch is an int, not a char: getchar() returns EOF as an int, and EOF may not
// survive a round trip through char (e.g. where char is unsigned).
ll read(){
	ll s = 0, f = 1;
	int ch = getchar();
	while((ch < '0' || ch > '9') && ch != EOF){
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9'){
		s = s * 10 + (ch - '0');
		ch = getchar();
	}
	return s * f;
}

// Seeds the priority RNG, builds the sentinel-bracketed treap, then answers n
// operations of the form "op x":
//   1 insert x, 2 delete one x, 3 rank of x, 4 value with rank x,
//   5 predecessor of x, 6 successor of x.
// Ranks are shifted by one (-1 for op 3, +1 for op 4) to account for the -INF
// sentinel that always occupies rank 1.
int main(){
	srand(time(0));
	build();
	n = read();

	for(int done = 0; done < n; ++done){
		int op = read();
		int x = read();
		if(op == 1){
			insert(rt, x);
		} else if(op == 2){
			remove(rt, x);
		} else if(op == 3){
			printf("%d\n", GetRankByVal(rt, x) - 1);
		} else if(op == 4){
			printf("%d\n", GetValByRank(rt, x + 1));
		} else if(op == 5){
			printf("%d\n", GetPre(x));
		} else if(op == 6){
			printf("%d\n", GetNext(x));
		}
	}

	return 0;
}
posted @ 2020-11-12 22:35  Tartarus_li  阅读(78)  评论(0编辑  收藏  举报