SBT指针板子

#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5+5;
const int INF = 1e9;

struct node{
	node *ch[2];
	int val,size;
	node(int v=0,int siz=0):val(v),size(siz){
		ch[0]=ch[1]=NULL;
	}
	void pushup(){size=ch[0]->size+ch[1]->size+1;}
	int rk(){return ch[0]->size+1;}
}*root,*null,poor[maxn],*tail,*st[maxn];
int top;

void init(){
	tail=poor;top=0;
	root=null=new node();
	null->ch[0]=null->ch[1]=null;
}
node *newnode(int val){
	node *o=new(top?st[top--]:tail++) node(val,1);
	o->ch[0]=o->ch[1]=null;
	return o; 
}
void rotate(node *&x,int k){
	node *w=x->ch[!k];
	x->ch[!k]=w->ch[k];
	w->ch[k]=x;
	w->size=x->size;
	x->pushup();
	x=w;
}
void maintain(node *&o,int k){
	if(o->ch[k]->ch[k]->size>o->ch[!k]->size)rotate(o,!k);
	else if(o->ch[k]->ch[!k]->size>o->ch[!k]->size)rotate(o->ch[k],k),rotate(o,!k);
	else return;
	maintain(o->ch[0],0);maintain(o->ch[1],1);
	maintain(o,0);maintain(o,1);
}
void insert(node *&o,int val){
	if(o==null){
		o=newnode(val);
		return;
	}
	o->size++;
	insert(o->ch[o->val<val],val);
	maintain(o,o->val<val);
}
void del(node *&o,int val){
	if(o->val!=val){
		del(o->ch[o->val<val],val);
		o->pushup();
		return ; 
	}
	o->size--;
	node *p=o;
	if(o->ch[0]==null)o=o->ch[1],st[++top]=p;
	else if(o->ch[1]==null)o=o->ch[0],st[++top]=p;
	else{
		p=o->ch[1];
		while(p->ch[0]!=null)p=p->ch[0];
		o->val=p->val;
		del(o->ch[1],p->val);
	}
}
int rank(int val){
	node *o=root;
	int res=0;
	while(o!=null){
		if(o->val<val)res+=o->rk(),o=o->ch[1];
		else o=o->ch[0];
	}
	return res+1;
}
int val(int rank){//
	node *o=root;
	while(o!=null){
		if(o->ch[0]->size>=rank)o=o->ch[0];
		else if(o->rk()>=rank)return o->val;
		else rank-=o->rk(),o=o->ch[1]; 
	}
}
int pre(int val){
	node *o=root;
	int res;
	while(o!=null){
		if(o->val<val){
			res=o->val;
			o=o->ch[1];
		}else o=o->ch[0];
	}
	return res;
}
int nex(int val){
	node *o=root;
	int res;
	while(o!=null){
		if(o->val>val){
			res=o->val;
			o=o->ch[0];
		}else o=o->ch[1];
	}
	return res;
}

int n,opt,x;

int main(){
	init(); 
	scanf("%d",&n);
	for(int i=1;i<=n;i++){
		scanf("%d%d",&opt,&x);
		if(opt==1)insert(root,x);
		if(opt==2)del(root,x);
		if(opt==3)printf("%d\n",rank(x));
		if(opt==4)printf("%d\n",val(x));
		if(opt==5)printf("%d\n",pre(x));
		if(opt==6)printf("%d\n",nex(x));
	}
	return 0;
} 
posted @ 2021-02-17 10:38  _Famiglistimo  阅读(82)  评论(0编辑  收藏  举报