【伸展树】[CQBZOJ2803]普通平衡树splay模板
贴代码
#include<cstdio>
#include<algorithm>
using namespace std;
#define MAXN 500000
int n,mi,ans;
struct node{
int val,cnt,size;
node *fa,*ch[2];
}splay_tree[MAXN+10],*tcnt=splay_tree,*root;
void Read(int &x){
char c;
bool f=0;
while(c=getchar(),c!=EOF){
if(c=='-')
f=1;
if(c>='0'&&c<='9'){
x=c-'0';
while(c=getchar(),c>='0'&&c<='9')
x=x*10+c-'0';
ungetc(c,stdin);
if(f)
x=-x;
return;
}
}
}
inline int Get_size(node *p){
return p?p->size:0;
}
inline void update(node *p){
p->size=p->cnt+Get_size(p->ch[0])+Get_size(p->ch[1]);
}
void Rotate(node *a,int d){
node *b=a->fa;
b->ch[!d]=a->ch[d];
a->fa=b->fa;
if(a->ch[d])
a->ch[d]->fa=b;
if(b->fa)
b->fa->ch[b==b->fa->ch[1]]=a;
b->fa=a;
a->ch[d]=b;
update(b);
}
void splay(node *x,node *rt){
node* y,*z;
while(x->fa!=rt){
y=x->fa;
z=y->fa;
if(z==rt){
if(x==y->ch[0])
Rotate(x,1);
else
Rotate(x,0);
}
else{
if(y==z->ch[0])
if(x==y->ch[0]){ //ZIG-ZIG
Rotate(y,1);
Rotate(x,1);
}
else{ //ZAG-ZIG
Rotate(x,0);
Rotate(x,1);
}
else{
if(x==y->ch[1]){ //ZAG-ZAG
Rotate(y,0);
Rotate(x,0);
}
else{ //ZIG-ZAG
Rotate(x,1);
Rotate(x,0);
}
}
}
}
update(x);
if(!rt)
root=x;
}
inline void clear(node *p){
p->ch[0]=p->ch[1]=NULL;
}
void insert(int val){
node *p=root,*fa=NULL;
while(p){
fa=p;
if(p->val<val)
p=p->ch[1];
else if(p->val>val)
p=p->ch[0];
else{
p->cnt++;
splay(p,NULL);
return;
}
}
p=++tcnt;
clear(p);
p->cnt=1;
p->fa=fa;
if(fa){
if(val<fa->val)
fa->ch[0]=p;
else
fa->ch[1]=p;
}
else
root=p;
p->val=val;
splay(p,NULL);
}
node *find(node *x,int d){
while(x&&x->ch[d])
x=x->ch[d];
return x;
}
void del(int val){
node *p=root;
while(p){
if(val<p->val)
p=p->ch[0];
else if(val>p->val)
p=p->ch[1];
else{
if(p->cnt>1){
p->cnt--;
splay(p,NULL);
return;
}
else{
splay(p,NULL);
node *ft=find(p->ch[0],1),*bk=find(p->ch[1],0);
if(!ft&&!bk)
root=NULL;
else if(!ft)
root=root->ch[1],root->fa=NULL;
else if(!bk)
root=root->ch[0],root->fa=NULL;
else{
splay(ft,NULL);
splay(bk,root);
bk->ch[0]=NULL;
bk->size--;
ft->size--;
}
}
return;
}
}
}
int find_pos(int val){
node *p=root;
int ret=0;
while(p){
if(val<p->val)
p=p->ch[0];
else if(val>p->val){
ret+=Get_size(p->ch[0])+p->cnt;
p=p->ch[1];
}
else
return ret+Get_size(p->ch[0])+1;
}
splay(p,NULL);
}
int pos_find(int pos){
node *p=root;
while(p){
if(!pos){
p=find(p,0);
return p->val;
}
if(pos<Get_size(p->ch[0]))
p=p->ch[0];
else if(pos>=Get_size(p->ch[0])+p->cnt)
pos-=Get_size(p->ch[0])+p->cnt,p=p->ch[1];
else
return p->val;
}
splay(p,NULL);
}
int find_bk(int val){
insert(val);
int ret=find(root->ch[1],0)->val;
del(val);
return ret;
}
int find_ft(int val){
insert(val);
int ret=find(root->ch[0],1)->val;
del(val);
return ret;
}
int main()
{
Read(n);
int a,b;
while(n--){
Read(a),Read(b);
if(a==1)
insert(b);
else if(a==2)
del(b);
else if(a==3)
printf("%d\n",find_pos(b));
else if(a==4)
printf("%d\n",pos_find(b-1));
else if(a==5)
printf("%d\n",find_ft(b));
else
printf("%d\n",find_bk(b));
}
}