平衡树板子

P3369 普通平衡树为例。

有旋Treap

#include <bits/stdc++.h>
using namespace std;
const int maxn=100010,inf=2147483645;
int q,opt,tot,root,x;
struct treap{
	int l,r,dat,val,siz,cnt;
}tr[maxn*5];
inline int read(){
	int s=0,w=1;char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-') w=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		s=(s<<1)+(s<<3)+(ch^48);
		ch=getchar(); 
	}
	return s*w;
} 
inline int New(int val){
	tr[++tot].val=val;
	tr[tot].dat=rand();
	tr[tot].cnt=tr[tot].siz=1;
	//新建点 
	return tot;
}
inline void update(int p){
	tr[p].siz=tr[tr[p].l].siz+tr[tr[p].r].siz+tr[p].cnt;
}
inline void build(){
	root=New(-inf);tr[root].r=New(inf);
	update(root);
	//建树 更新信息 
}
int getrank(int p,int val){
	if(!p) return 0;
	if(val==tr[p].val) return tr[tr[p].l].siz+1;
	//查询到该值,返回左子树的大小+1(即排名 
	if(val<tr[p].val) return getrank(tr[p].l,val);
	return getrank(tr[p].r,val)+tr[tr[p].l].siz+tr[p].cnt;
}
int getval(int p,int rank){
	if(!p) return inf;
	if(rank<=tr[tr[p].l].siz) return getval(tr[p].l,rank);
	//必须写在前面 
	if(rank<=tr[tr[p].l].siz+tr[p].cnt) return tr[p].val;
	//若排名在这个区间范围内可以直接返回值 
	return getval(tr[p].r,rank-tr[tr[p].l].siz-tr[p].cnt);
}
inline void zig(int &p){
	//注意引用 右旋 
	int q=tr[p].l;
	tr[p].l=tr[q].r,tr[q].r=p,p=q;
	update(tr[p].r),update(p);
}
inline void zag(int &p){
	//注意引用 左旋 
	int q=tr[p].r;
	tr[p].r=tr[q].l,tr[q].l=p,p=q;
	update(tr[p].l),update(p);
}
inline void insert(int &p,int val){
	//注意引用 插入
	if(!p){
		p=New(val);
		return;
	}
	if(val==tr[p].val){
		tr[p].cnt++;
		update(p);
		return;
	}
	if(val<tr[p].val){
		insert(tr[p].l,val);
		if(tr[p].dat<tr[tr[p].l].dat) zig(p); 
		//在这里右旋 不满足堆性质 
	}else{
		insert(tr[p].r,val);
		if(tr[p].dat<tr[tr[p].r].dat) zag(p); 
		//在这里左旋 
	}
	update(p);
	//注意更新 
} 
inline int getpre(int val){
	int ans=1,p=root;
	while(p){
		if(val==tr[p].val){
			if(tr[p].l){
				p=tr[p].l;
				while(tr[p].r) p=tr[p].r;
				ans=p;
			}
			break;
		}
		if(val>tr[p].val&&tr[p].val>tr[ans].val) ans=p;
		p=val<tr[p].val?tr[p].l:tr[p].r;
	}
	return tr[ans].val;
}
inline int getnxt(int val){
	int ans=2,p=root;
	while(p){
		if(val==tr[p].val){
			if(tr[p].r){
				p=tr[p].r;
				while(tr[p].l) p=tr[p].l;
				ans=p;
			}
			break;
		}
		if(val<tr[p].val&&tr[p].val<tr[ans].val) ans=p;
		p=val<tr[p].val?tr[p].l:tr[p].r;
	}
	return tr[ans].val;
}
void remove(int &p,int val){
	//注意引用 删除
	if(!p) return;
	if(val==tr[p].val){
		if(tr[p].cnt>1){
			tr[p].cnt--;
			update(p);
			return;
		}
		if(tr[p].l||tr[p].r){
			if(!tr[p].r||tr[tr[p].l].dat>tr[tr[p].r].dat) zig(p),remove(tr[p].r,val);
			else zag(p),remove(tr[p].l,val);
			update(p);
			//注意更新 
		}else p=0;
		return;
	}
	val<tr[p].val?remove(tr[p].l,val):remove(tr[p].r,val);
	update(p);
}
int main(){
	build();q=read();
	while(q--){
		opt=read();x=read();
		if(opt==1) insert(root,x);
		else if(opt==2) remove(root,x);
		else if(opt==3) printf("%d\n",getrank(root,x)-1);
		else if(opt==4) printf("%d\n",getval(root,x+1));
		else if(opt==5) printf("%d\n",getpre(x));
		else printf("%d\n",getnxt(x));
	}
	return 0;
} 

无旋Treap

#include <bits/stdc++.h>
#define lson(x) tr[(x)].l
#define rson(x) tr[(x)].r
using namespace std;
const int maxn=2e5+20;
int q,opt,tot,root;
struct N_treap{
	int l,r,val,cnt,dat,siz;
}tr[maxn*5];
inline void pushup(int p){
	tr[p].siz=tr[p].cnt+tr[lson(p)].siz+tr[rson(p)].siz;
}
inline int New(int val){
	tr[++tot].val=val;
	tr[tot].dat=rand();
	lson(tot)=rson(tot)=0;
	tr[tot].cnt=tr[tot].siz=1;
	return tot;
}
inline int merge(int l,int r){
	if(!l||!r) return l+r;
	if(tr[l].dat>tr[r].dat){
		rson(l)=merge(rson(l),r);
		pushup(l);
		return l;
	}else{
		lson(r)=merge(l,lson(r));
		pushup(r);
		return r;
	}
}
inline void spilt(int p,int val,int &l,int &r){
	if(!p) l=r=0;
	else{
		if(tr[p].val<=val){
			l=p;
			spilt(rson(l),val,rson(l),r);
			pushup(l);
		}else{
			r=p;
			spilt(lson(r),val,l,lson(r));
			pushup(r);
		}
	}
} 
inline void insert(int val){
	int l,r;spilt(root,val,l,r);
	root=merge(merge(l,New(val)),r);
}
inline void remove(int val){
	int l,r,s;spilt(root,val-1,l,r);
	spilt(r,val,r,s);
	r=merge(lson(r),rson(r));
	root=merge(merge(l,r),s);
	return;
}
inline int getrank(int val){
	int l,r;
	spilt(root,val-1,l,r);
	int res=tr[l].siz+1;
	root=merge(l,r);
	return res;
}
inline int kth(int p,int k){
	if(k<=tr[lson(p)].siz) return kth(lson(p),k);
	k-=tr[lson(p)].siz+tr[p].cnt;
	if(k<=0) return p;
	else return kth(rson(p),k);
}
inline int Kth(int k){
	return tr[kth(root,k)].val;
}
inline int pre(int val){
	int l,r;
	spilt(root,val-1,l,r);
	int res=tr[kth(l,tr[l].siz)].val;
	root=merge(l,r);
	return res;
}
inline int suc(int val){
	int l,r;
	spilt(root,val,l,r);
	int res=tr[kth(r,1)].val;
	root=merge(l,r);
	return res;
}
inline int read(){
	int s=0,w=1;char ch=getchar();
	while(ch<'0'||ch>'9'){
		if(ch=='-') w=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9'){
		s=(s<<1)+(s<<3)+(ch^48);
		ch=getchar();
	}
	return s*w;
}
int main(){
	q=read();
	while(q--){
		int dos;
		opt=read();dos=read();
		if(opt==1) insert(dos);
		else if(opt==2) remove(dos);
		else if(opt==3) printf("%d\n",getrank(dos));
		else if(opt==4) printf("%d\n",Kth(dos));
		else if(opt==5) printf("%d\n",pre(dos));
		else printf("%d\n",suc(dos));
	}
	return 0;
}

Splay

#include <bits/stdc++.h>
#define lson(x) tr[(x)].son[0]
#define rson(x) tr[(x)].son[1]
#define fu(x) tr[(x)].fa
using namespace std;
namespace Broken_Eclipse{
	inline int read(){
		int s=0,w=1;char ch;
		while(!isdigit(ch=getchar())) if(ch=='-') w=-1;
		do s=s*10+(ch^48);while(isdigit(ch=getchar()));
		return s*w;
	} 
}
using namespace Broken_Eclipse;
const int maxn=100010;
int root,tot,q,opt,x;
struct splay{
	int son[2],val,cnt,siz,fa;
}tr[maxn*5];
inline void update(int p){
	tr[p].siz=tr[lson(p)].siz+tr[rson(p)].siz+tr[p].cnt;
} 
inline bool get(int p){
	return p==rson(fu(p));
}
inline void clear(int p){
	lson(p)=rson(p)=fu(p)=tr[p].val=tr[p].siz=tr[p].cnt=0;
}
inline void poer(int p){
	int y=fu(p),z=fu(y),chk=get(p);
	tr[y].son[chk]=tr[p].son[chk^1];
	if(tr[p].son[chk^1]) fu(tr[p].son[chk^1])=y;
	tr[p].son[chk^1]=y;
	fu(y)=p;
	fu(p)=z;
	if(z) tr[z].son[y==rson(z)]=p;
	update(p);
	update(y);
	return;
}
inline void splay(int p){
	for(int f=fu(p);f=fu(p),f;poer(p))
		if(fu(f)) poer(get(p)==get(f)?f:p);
	root=p;
}
inline void insert(int val){
	if(!root){
		tr[++tot].val=val;
		tr[tot].cnt++;
		root=tot;
		update(root);
		return;
	}
	int cur=root,f=0;
	while(1){
		if(tr[cur].val==val){
			tr[cur].cnt++;
			update(cur);
			update(f);
			splay(cur);
			break;
		}
		f=cur;
		cur=tr[cur].son[tr[cur].val<val];
		if(!cur){
			tr[++tot].val=val;
			tr[tot].cnt++;
			fu(tot)=f;
			tr[f].son[tr[f].val<val]=tot;
			update(tot);
			update(f);
			splay(tot);
			break;
		} 
	}
}
inline int getrank(int val){
	int res=0,cur=root;
	while(1){
		if(val<tr[cur].val) cur=lson(cur);
		else{
			res+=tr[lson(cur)].siz;
			if(val==tr[cur].val){
				splay(cur);
				return res+1;
			}
			res+=tr[cur].cnt;
			cur=rson(cur);
		}
	}
}
inline int getval(int rank){
	int cur=root;
	while(1){
		if(lson(cur)&&rank<=tr[lson(cur)].siz) cur=lson(cur);
		else{
			rank-=tr[cur].cnt+tr[lson(cur)].siz;
			if(rank<=0){
				splay(cur);
				return tr[cur].val;
			}
			cur=rson(cur);
		}
	}
}
inline int getpre(){
	int cur=lson(root);
	if(!cur) return cur;
	while(rson(cur)) cur=rson(cur);
	splay(cur);
	return cur;
}
inline int getnxt(){
	int cur=rson(root);
	if(!cur) return cur;
	while(lson(cur)) cur=lson(cur);
	splay(cur);
	return cur;
}
inline void del(int val){
	getrank(val);
	if(tr[root].cnt>1){
		tr[root].cnt--;
		update(root);
		return;
	}
	if(!lson(root)&&!rson(root)){
		clear(root);
		root=0;
		return;
	}
	if(!lson(root)){
		int cur=root;
		root=rson(root);
		fu(root)=0;
		clear(cur);
		return;
	}
	if(!rson(root)){
		int cur=root;
		root=lson(root);
		fu(root)=0;
		clear(cur);
		return;
	}
	int cur=root;
	int ze=getpre();
	fu(rson(cur))=ze;
	rson(ze)=rson(cur);
	clear(cur);
	update(root);
	return; 
}
int main(){
	q=read();
	while(q--){
		opt=read();x=read();
		if(opt==1) insert(x);
		else if(opt==2) del(x);
		else if(opt==3) printf("%d\n",getrank(x));
		else if(opt==4) printf("%d\n",getval(x));
		else if(opt==5) insert(x),printf("%d\n",tr[getpre()].val),del(x);
		else insert(x),printf("%d\n",tr[getnxt()].val),del(x);
	}
	return 0;
}
posted @ 2022-06-08 14:05  Broken_Eclipse  阅读(22)  评论(0编辑  收藏  举报

Loading