平衡树模板

Treap

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <iostream>
// Sentinel "infinity" stored in the -inf/+inf guard nodes of the treap;
// assumed to exceed the magnitude of every key read from input (not checked).
const int inf = 0x3f3f3f3f;
typedef long long ll;
using namespace std;
// Size of the treap node pool (index 0 is reserved as the null node).
const int MAXN = 100005;
// Reads the next integer from stdin.
// Non-digit characters before the number are skipped; if a '-' is among
// them the result is negated. The character right after the last digit is
// consumed. Returns 0 if end of input is reached before any digit.
int init() {
	int rv = 0, fh = 1;
	// Keep getchar()'s result in an int so EOF cannot be confused with a
	// valid byte (and so the skip loop below can terminate at EOF).
	int c = getchar();
	while(c != EOF && (c < '0' || c > '9')) {
		if(c == '-') fh = -1;
		c = getchar();
	}
	while(c >= '0' && c <= '9') {
		rv = rv * 10 + (c - '0');
		c = getchar();
	}
	return fh * rv;
}
/*
 * Treap: randomized balanced binary search tree holding a multiset of ints.
 * Nodes are ordered as a BST by `val` and as a max-heap by the random
 * priority `dat`; equal values share one node through `cnt`.
 *
 * Conventions:
 *  - index 0 is the null node; its size and priority must stay 0;
 *  - Build() must be called before any other operation. It creates two
 *    sentinel nodes, -inf (index 1) and +inf (index 2), so every internal
 *    rank is shifted by one (work() compensates for this);
 *  - nodes are never recycled, so at most MAXN - 1 nodes (sentinels
 *    included) can ever be created; New() does not check this bound.
 */
struct Treap{
	struct node{
		int val;   // stored key
		int l, r;  // child indices, 0 = no child
		int cnt;   // multiplicity of val
		int size;  // total multiplicity in this subtree
		int dat;   // random heap priority
	}a[MAXN];
	int tot;  // index of the last allocated node
	int rot;  // index of the root node

	// Allocates a leaf holding one copy of x with a random priority and
	// returns its index.
	int New(int x) {
		++tot;
		a[tot].val = x;
		a[tot].l = a[tot].r = 0;
		a[tot].cnt = a[tot].size = 1;
		a[tot].dat = rand();
		return tot;
	}

	// Recomputes the subtree size of rt from its children and its count.
	void PushUp(int rt) {
		a[rt].size = a[a[rt].l].size + a[a[rt].r].size + a[rt].cnt;
	}

	// (Re)initializes the tree so that it contains only the two sentinels.
	void Build() {
		tot = 0;
		a[0] = node();  // null node: size 0, priority 0, no children
		New(-inf); New(inf);
		rot = 1;
		a[1].r = 2;
		PushUp(rot);
		// The sentinels are linked by hand, so restore heap order if +inf
		// happened to draw the larger priority.
		if(a[2].dat > a[1].dat) zag(rot);
	}

	// Returns 1 + the number of stored values (with multiplicity, sentinels
	// included) strictly less than x in the subtree rooted at rt.
	int GetRank(int rt, int x) {
		if(!rt) return 1;
		if(x == a[rt].val) return a[a[rt].l].size + 1;
		if(x < a[rt].val) return GetRank(a[rt].l, x);
		return GetRank(a[rt].r, x) + a[a[rt].l].size + a[rt].cnt;
	}

	// Returns the x-th smallest value (1-based, with multiplicity, sentinels
	// included) in the subtree rooted at rt, or inf if x lies outside
	// [1, subtree size].
	int GetVal(int rt, int x) {
		if(!rt) return inf;
		if(a[a[rt].l].size >= x) return GetVal(a[rt].l, x);
		if(a[a[rt].l].size + a[rt].cnt >= x) return a[rt].val;
		return GetVal(a[rt].r, x - a[a[rt].l].size - a[rt].cnt);
	}

	// Right rotation: the left child of rt becomes the new subtree root.
	void zig(int &rt) {
		int nxt = a[rt].l;
		a[rt].l = a[nxt].r; a[nxt].r = rt; rt = nxt;
		PushUp(a[rt].r);
		PushUp(rt);
	}

	// Left rotation: the right child of rt becomes the new subtree root.
	void zag(int &rt) {
		int nxt = a[rt].r;
		a[rt].r = a[nxt].l; a[nxt].l = rt; rt = nxt;
		PushUp(a[rt].l);
		PushUp(rt);
	}

	// Inserts one copy of x into the subtree rooted at rt (rt is updated if
	// the subtree root changes). After recursing into a child, that child is
	// rotated above rt if its priority now exceeds rt's.
	void Insert(int &rt, int x) {
		if(!rt) {
			rt = New(x);
			return ;
		}
		if(x == a[rt].val) {
			a[rt].cnt++;
		} else if(x < a[rt].val) {
			Insert(a[rt].l, x);
			if(a[a[rt].l].dat > a[rt].dat) zig(rt);
		} else {
			// `else`, not a second test: a rotation above may already have
			// replaced rt, and x must not be inserted twice.
			Insert(a[rt].r, x);
			if(a[a[rt].r].dat > a[rt].dat) zag(rt);
		}
		PushUp(rt);
	}

	// Returns the largest stored value strictly less than x, or -inf (the
	// sentinel at index 1) if there is none. Every node passed while moving
	// right is a candidate; if x itself is found, the answer is the maximum
	// of its left subtree when that subtree is non-empty.
	int GetPre(int x) {
		int ans = 1;
		int rt = rot;
		while(rt) {
			if(x == a[rt].val) {
				if(a[rt].l) {
					rt = a[rt].l;
					while(a[rt].r) rt = a[rt].r;
					ans = rt;
				}
				break;
			}
			if(x > a[rt].val && a[rt].val > a[ans].val) ans = rt;
			rt = x > a[rt].val ? a[rt].r : a[rt].l;
		}
		return a[ans].val;
	}

	// Returns the smallest stored value strictly greater than x, or +inf (the
	// sentinel at index 2) if there is none. Mirror image of GetPre.
	int GetNext(int x) {
		int ans = 2;
		int rt = rot;
		while(rt) {
			if(x == a[rt].val) {
				if(a[rt].r) {
					rt = a[rt].r;
					while(a[rt].l) rt = a[rt].l;
					ans = rt;
				}
				break;
			}
			if(x < a[rt].val && a[rt].val < a[ans].val) ans = rt;
			rt = x > a[rt].val ? a[rt].r : a[rt].l;
		}
		return a[ans].val;
	}

	// Removes one copy of x from the subtree rooted at rt; no-op if x is
	// absent. A node whose last copy is removed is rotated down (the child
	// with the higher priority goes up) until it is a leaf, then unlinked.
	// Its slot in a[] is not reused.
	void Remove(int &rt, int x) {
		if(!rt) return ;
		if(x != a[rt].val) {
			if(x > a[rt].val) Remove(a[rt].r, x);
			else Remove(a[rt].l, x);
			PushUp(rt);
			return ;
		}
		if(a[rt].cnt > 1) {
			a[rt].cnt--;
			PushUp(rt);
			return ;
		}
		if(!a[rt].l && !a[rt].r) {
			rt = 0;  // leaf holding the last copy: unlink it
			return ;
		}
		// Rotate the higher-priority child up; a missing child is never
		// chosen. Then continue removing x from the side it moved to.
		if(!a[rt].r || (a[rt].l && a[a[rt].l].dat > a[a[rt].r].dat)) {
			zig(rt);
			Remove(a[rt].r, x);
		} else {
			zag(rt);
			Remove(a[rt].l, x);
		}
		PushUp(rt);
	}

	// Executes one query and prints its answer (if any) on its own line.
	//   1 x: insert x              2 x: remove one copy of x
	//   3 x: print rank of x       4 k: print k-th smallest value
	//   5 x: print predecessor     6 x: print successor
	// The -inf sentinel always holds rank 1, so rank results are decreased
	// by one and select arguments are increased by one.
	void work(int opt, int num) {
		switch (opt) {
			case 1 : Insert(rot, num); break;
			case 2 : Remove(rot, num); break;
			case 3 : printf("%d\n", GetRank(rot, num) - 1); break;
			case 4 : printf("%d\n", GetVal(rot, num + 1)); break;
			case 5 : printf("%d\n", GetPre(num)); break;
			case 6 : printf("%d\n", GetNext(num)); break;
		}
	}
}bst;
int n ;
/*
 * Program entry: reads the operation count n, builds the treap with its
 * sentinels, then reads and executes n (opt, num) pairs in input order.
 */
int main() {
	n = init();
	bst.Build();
	for(int remaining = n; remaining > 0; --remaining) {
		const int opt = init();  // opt precedes num in the input
		const int num = init();
		bst.work(opt, num);
	}
	return 0;
}
posted @ 2018-03-08 08:49  Mr_Wolfram  阅读(213)  评论(2)  编辑  收藏  举报