splay版

指针是个好东西

不过就是得判空

还有别忘了传引用(其实指针参数都应该传引用)

#include<cstdio>
#include<algorithm>
#include<iostream>
using namespace std;
// Largest 32-bit signed value (0x7fffffff); merge() passes it to splay() as the search key.
int inf=0x7fffffff;
// Shared sentinel used in place of a null child; main() allocates and wires it up
// before any tree operation runs.
struct node* nil;

// One vertex of the tree: a key together with the number of copies stored for it.
struct node
{
	int num;      // how many copies of val this vertex represents
	int val;      // the key
	int size;     // total of num over this vertex and both subtrees (see sum())
	node* ch[2];  // ch[0] is the left child, ch[1] the right child

	// Creates a vertex holding one copy of v whose children are both the sentinel.
	node (int v) :num(1),val(v),size(1)
	{
		ch[0]=nil;
		ch[1]=nil;
	}

	// Recomputes size from num and the sizes of the non-sentinel children.
	void sum()
	{
		int total=num;
		for(int side=0;side<2;side++)
			if(ch[side]!=nil)
				total+=ch[side]->size;
		size=total;
	}

	// Direction to continue a search for key v from this vertex:
	// 0 = go left, 1 = go right, -1 = v is this vertex's key.
	int cmp(int v)
	{
		if(v<val)
			return 0;
		if(v>val)
			return 1;
		return -1;
	}

	// Direction to continue a search for the k-th smallest element (1-based) of
	// this subtree: 0 = it lies in the left subtree, 1 = in the right subtree,
	// -1 = it is one of the copies held by this vertex.
	int cmpkth(int k)
	{
		int left=0;
		if(ch[0]!=nil)
			left=ch[0]->size;
		if(k<=left)
			return 0;
		if(k>left+num)
			return 1;
		return -1;
	}
};
// Root of the single global tree that main() passes to every operation.
node* root;
// Debug helper: walks the subtree under x in-order (left, vertex, right) and prints
// each vertex's key followed by a space. A key is printed once per vertex,
// regardless of how many copies the vertex holds.
void visit(node* x)
{
	if(x!=nil)
	{
		visit(x->ch[0]);
		printf("%d ",x->val);
		visit(x->ch[1]);
	}
}
// Single rotation at x: the child on side (base^1) moves up to take x's place and
// the old top becomes that child's `base`-side child. Sizes are refreshed bottom-up
// and the caller's pointer slot is redirected to the new top.
void rotato(node* &x,int base)
{
	const int from=base^1;
	node* old_top=x;
	node* riser=old_top->ch[from];
	// The riser's inner subtree changes parent before the riser adopts the old top.
	old_top->ch[from]=riser->ch[base];
	riser->ch[base]=old_top;
	// The old top is now the lower vertex, so its size must be recomputed first.
	old_top->sum();
	riser->sum();
	x=riser;
}
// Recursive splay by key: searches for v below x, two levels per call, and rotates
// the vertex where the search ends up into x. The search ends at the vertex whose
// key equals v, or at the last vertex before the path would step onto the sentinel.
void splay(node* &x,int v)
{
	const int d=x->cmp(v);
	// Already at the target, or nowhere further to go: x stays on top.
	if(d==-1||x->ch[d]==nil)
		return ;
	node* child=x->ch[d];
	const int d2=child->cmp(v);
	// The search ends at the child: one rotation lifts it.
	if(d2==-1||child->ch[d2]==nil)
	{
		rotato(x,d^1);
		return ;
	}
	// Bring the target up to the grandchild slot first, then lift it two levels.
	splay(child->ch[d2],v);
	if(d==d2)
	{
		// Same direction twice: rotate at the top, then at the new top.
		rotato(x,d^1);
		rotato(x,d^1);
	}
	else
	{
		// Opposite directions: rotate at the child slot, then at the top.
		rotato(x->ch[d],d2^1);
		rotato(x,d^1);
	}
}
// Recursive splay by rank: rotates the vertex holding the k-th smallest element
// (1-based, copies counted) of the subtree under x into x.
//
// Like splay(), the descent never steps onto the nil sentinel. If k is outside
// 1..size, or the tree is empty, the search stops at the last real vertex on its
// path (the smallest key for k<1, the largest for k>size) instead of recursing on
// the sentinel without end.
void splaykth(node* &x,int k)
{
	const int d=x->cmpkth(k);
	// Found here, or no subtree on the side we would need to enter.
	if(d==-1||x->ch[d]==nil)
		return ;
	// Entering the right subtree: discount everything at or left of x.
	// (nil->size is 0, so a missing left child contributes nothing.)
	if(d==1)
		k-=x->ch[0]->size+x->num;
	node* child=x->ch[d];
	const int d2=child->cmpkth(k);
	// The search ends at the child: one rotation lifts it.
	if(d2==-1||child->ch[d2]==nil)
	{
		rotato(x,d^1);
		return ;
	}
	const int k2=(d2==1 ? k-child->ch[0]->size-child->num : k);
	splaykth(child->ch[d2],k2);
	if(d==d2)
	{
		// Same direction twice: rotate at the top, then at the new top.
		rotato(x,d^1);
		rotato(x,d^1);
	}
	else
	{
		// Opposite directions: rotate at the child slot, then at the top.
		rotato(x->ch[d],d2^1);
		rotato(x,d^1);
	}
	return ;
}
// Predecessor search: raises ans to the largest key strictly below val found under x.
// ans is only ever increased, so the caller seeds it with a value lower than any key;
// it is left untouched when no key below val exists.
void pre(node* x,int val,int &ans)
{
	while(x!=nil)
	{
		if(x->val<val)
		{
			// Candidate found; anything better can only be to the right.
			if(ans<x->val)
				ans=x->val;
			x=x->ch[1];
		}
		else
			x=x->ch[0];
	}
}
// Successor search: lowers ans to the smallest key strictly above val found under x.
// ans is only ever decreased, so the caller seeds it with a value higher than any key;
// it is left untouched when no key above val exists.
void nxt(node* x,int val,int &ans)
{
	while(x!=nil)
	{
		if(x->val>val)
		{
			// Candidate found; anything better can only be to the left.
			if(ans>x->val)
				ans=x->val;
			x=x->ch[0];
		}
		else
			x=x->ch[1];
	}
}
// Rank of val in the tree under x: the number of stored elements strictly smaller
// than val (copies counted), plus one. The tree is splayed on val as a side effect.
//
// val does not have to be present. After the splay the root is either val itself
// or the last vertex on its search path; when that vertex's key is smaller than
// val, its own copies are also smaller than val and are added to the count.
// An empty tree yields 1.
int find(node* &x,int val)
{
	splay(x,val);
	// Everything in the left subtree is smaller than the root (nil->size is 0).
	int smaller=x->ch[0]->size;
	if(x!=nil&&x->val<val)
		smaller+=x->num;
	return smaller+1;
}
// Returns the key of the k-th smallest element (1-based, copies counted) of the tree
// under x. splaykth() moves the matching vertex to the root, whose key is returned.
// k is expected to lie within 1..size of the tree.
int kth(node* &x,int k)
{
	splaykth(x,k);
	node* top=x;
	return top->val;
}
// Splits the tree under x around val: on return x holds the keys <= val and the
// returned tree holds the keys > val. Either part may be the nil sentinel.
// An empty input is returned as two empty parts.
node *spilt(node* &x,int val)
{
	if(x==nil)
		return nil;
	splay(x,val);
	node* top=x;
	node* low;
	node* high;
	if(top->val>val)
	{
		// The root belongs to the upper part; its left subtree is the lower part.
		high=top;
		low=top->ch[0];
		top->ch[0]=nil;
	}
	else
	{
		// The root belongs to the lower part; its right subtree is the upper part.
		low=top;
		high=top->ch[1];
		top->ch[1]=nil;
	}
	// The former root just lost a subtree, so its size is stale.
	top->sum();
	x=low;
	return high;
}
// Joins t2 onto t1, leaving the result in t1 and resetting t2 to the sentinel.
// The caller must ensure every key in t1 is smaller than every key in t2.
// t1 is splayed on inf first, and t2 is then attached as the root's right child.
void merge(node* &t1,node* &t2)
{
	// An empty left operand simply takes over the right one.
	if(t1==nil)
	{
		t1=t2;
		t2=nil;
	}
	splay(t1,inf);
	t1->ch[1]=t2;
	t2=nil;
	t1->sum();
}
// Adds one copy of val to the tree under x.
// The tree is split around val; if the lower part's root already holds val its
// count is bumped, otherwise a new vertex is appended to the lower part. The two
// parts are then joined again.
void insert(node* &x,int val)
{
	node* greater=spilt(x,val);
	if(x==nil||x->val!=val)
	{
		node* fresh=new node(val);
		merge(x,fresh);
	}
	else
	{
		x->num+=1;
		x->sum();
	}
	merge(x,greater);
}
// Removes one copy of val from the tree under x; does nothing if val is not stored.
// The tree is split around val, so when val is present it is the root of the lower
// part. Its count is decremented and the vertex is freed once the count reaches
// zero. The two parts are then joined again.
void erase(node* &x,int val)
{
	node* greater=spilt(x,val);
	// Only touch the root when it really is val. Otherwise this would decrement
	// some other key's count, or the shared nil sentinel's.
	if(x!=nil&&x->val==val)
	{
		x->num-=1;
		if(x->num==0)
		{
			// After the split the root has no right child, so its left subtree
			// is all that remains of the lower part.
			node* dead=x;
			x=x->ch[0];
			delete dead;
		}
		else
			x->sum();
	}
	merge(x,greater);
}
// Reads the next decimal integer from standard input.
// Non-digit characters before the number are skipped; if any of them is '-', the
// result is negated. The character that terminates the digit run is consumed.
// Returns 0 if end of input is reached before any digit is seen.
int read()
{
	int s=0,f=1;
	// getchar() returns int so that EOF can be told apart from every character.
	int in=getchar();
	while(in!=EOF&&(in<'0'||in>'9'))
	{
		if(in=='-')
			f=-1;
		in=getchar();
	}
	// EOF is negative, so it also ends this loop.
	while(in>='0'&&in<='9')
	{
		s=s*10+(in-'0');
		in=getchar();
	}
	return s*f;
}
// Driver: reads an operation count n, then n pairs "op arg" and applies each to the
// global tree.
//   1 insert arg            2 erase one copy of arg
//   3 print rank of arg     4 print the arg-th smallest key
//   5 print the largest key below arg (-0x7fffffff if none)
//   6 print the smallest key above arg (0x7fffffff if none)
// Any other op is ignored. Operations 5 and 6 insert arg before the query and
// erase it again afterwards.
int main()
{
	int n=read();
	// Build the sentinel: it is its own child on both sides and counts as empty.
	nil=new node(0);
	nil->ch[0]=nil;
	nil->ch[1]=nil;
	nil->size=0;
	nil->num=0;
	root=nil;
	for(int left=n;left>0;left--)
	{
		int op=read();
		int arg=read();
		if(op==1)
			insert(root,arg);
		else if(op==2)
			erase(root,arg);
		else if(op==3)
			printf("%d\n",find(root,arg));
		else if(op==4)
			printf("%d\n",kth(root,arg));
		else if(op==5||op==6)
		{
			int ans;
			insert(root,arg);
			if(op==5)
			{
				ans=-0x7fffffff;
				pre(root,arg,ans);
			}
			else
			{
				ans=0x7fffffff;
				nxt(root,arg,ans);
			}
			printf("%d\n",ans);
			erase(root,arg);
		}
	}
	return 0;
}
posted @ 2018-04-25 21:29  Lance1ot  阅读(148)  评论(0)  编辑  收藏  举报