Treap模板(旋转)

目前只见过把 Treap 当作名次树(支持查排名、查第 k 小的平衡树)来用,其他用法暂时还没见过 ←_←

#include<bits/stdc++.h>
#include<stdio.h>
#include<algorithm>
#include<queue>
#include<string.h>
#include<iostream>
#include<math.h>
#include<set>
#include<map>
#include<vector>
#include<iomanip>
using namespace std;
// Type shorthands and helper macros.
#define ll long long
#define ull unsigned long long
#define pb push_back
// FOR(a): loop with its own `int i` running from 1 to a inclusive.
#define FOR(a) for(int i=1;i<=a;i++)
// 0x3f3f3f3f == 1061109567, a large int sentinel.
const int inf=0x3f3f3f3f;
// maxn and mod are not referenced by the pointer-based treap visible below.
const int maxn=2e6+7; 
const long long mod=1e9+7;

// One treap node. Equal keys share a node and are counted in w.
struct NODE{
	NODE *ch[2];	// ch[0] = left child, ch[1] = right child
	int v;			// key stored in this node
	int r;			// random priority, compared by operator<
	int s;			// number of elements in this subtree, duplicates included
	int w;			// how many times v occurs
	// Creates a leaf holding one copy of v with a priority taken from rand().
	NODE(int v):v(v),r(rand()),s(1),w(1){
		ch[0]=NULL;
		ch[1]=NULL;
	}
	// Orders nodes by priority only.
	bool operator < (const NODE &rhs)const{
		return r<rhs.r;
	}
	// Key comparison: -1 if x is this node's key, 0 if x belongs to the
	// left subtree, 1 if it belongs to the right subtree.
	int cmp(int x)const{
		if(x<v)return 0;
		if(x>v)return 1;
		return -1;
	}
	// Rank comparison for a 1-based rank x inside this subtree: -1 if the
	// x-th element is one of this node's copies, 0 if it lies in the left
	// subtree, 1 if it lies in the right subtree.
	int cmp1(int x)const{
		const int below=ch[0]?ch[0]->s:0;	// elements strictly before this node
		if(x<=below)return 0;
		if(x<=below+w)return -1;
		return 1;
	}
	// Recomputes s from w and the children's sizes.
	void maintain(){
		int total=w;
		for(int side=0;side<2;++side){
			if(ch[side])total+=ch[side]->s;
		}
		s=total;
	}
}*root;
// Rotates the subtree rooted at o. d==0 is a left rotation (the right child
// becomes the new root), d==1 is a right rotation (the left child rises).
// The child on side d^1 must exist. Sizes of both touched nodes are
// recomputed, lower node first.
void rotate(NODE* &o,int d){
	const int up=d^1;			// side of the child that moves up
	NODE *lower=o;
	NODE *upper=lower->ch[up];
	lower->ch[up]=upper->ch[d];	// hand over the inner subtree
	upper->ch[d]=lower;
	lower->maintain();
	upper->maintain();
	o=upper;
}
// Inserts one copy of v into the treap rooted at o.
// An existing key only gets its multiplicity bumped; a new key is attached
// as a leaf and rotated upwards while its priority exceeds its parent's,
// keeping a max-heap on NODE::r.
void insert(NODE* &o,int v){
	if(!o){
		o=new NODE(v);	// empty slot: the new leaf already has s == w == 1
		return;
	}
	int d=o->cmp(v);
	if(d==-1){
		o->w++;
	}else{
		insert(o->ch[d],v);
		// Compare priorities, not pointer addresses: addresses handed out by
		// the allocator are not random, so using them can degenerate the tree.
		if(o->r < o->ch[d]->r)rotate(o,d^1);
	}
	o->maintain();
}
// Removes one copy of v from the treap rooted at o.
// Does nothing when v is not stored. A node whose multiplicity drops to zero
// is rotated down (the child with the larger priority rises, preserving the
// max-heap on NODE::r) until it has at most one child, then unlinked and freed.
void del(NODE* &o,int v){
	if(!o)return;	// v is not in the tree: nothing to delete
	int d=o->cmp(v);
	if(d!=-1){
		del(o->ch[d],v);
	}else if(o->w>1){
		o->w--;
	}else if(o->ch[0]&&o->ch[1]){
		// Compare priorities, not pointer addresses.
		int d2=(o->ch[0]->r > o->ch[1]->r)?1:0;
		rotate(o,d2);
		del(o->ch[d2],v);	// the node to remove now hangs on side d2
	}else{
		NODE *dead=o;
		o=o->ch[0]?o->ch[0]:o->ch[1];
		delete dead;	// release the unlinked node instead of leaking it
	}
	if(o)o->maintain();
}
// Frees every node of the subtree rooted at o (children first) and resets o
// to NULL. Safe to call on an empty subtree.
void remove(NODE* &o){
	if(o==NULL)return;
	for(int side=0;side<2;++side){
		if(o->ch[side]!=NULL)remove(o->ch[side]);
	}
	delete o;
	o=NULL;
}
// Returns how many copies of x are stored in the subtree rooted at o
// (0 when x is absent). The tree is not modified.
int find(NODE* &o,int x){
	NODE *cur=o;
	while(cur!=NULL){
		const int side=cur->cmp(x);
		if(side==-1)return cur->w;
		cur=cur->ch[side];
	}
	return 0;
}
// Returns the k-th smallest element (1-based, duplicates counted) of the
// subtree rooted at o.
// Returns 0 when k is outside 1..size instead of dereferencing a null child;
// this is reached e.g. when a predecessor/successor query has no answer.
int kth(NODE* o,int k){
	while(o){
		int d=o->cmp1(k);
		if(d==-1)return o->v;
		if(d==1){
			// Skip this node's copies and its whole left subtree.
			k-=o->w;
			if(o->ch[0])k-=o->ch[0]->s;
		}
		o=o->ch[d];
	}
	return 0;
}
// Counts the stored elements strictly smaller than x in the subtree rooted
// at o (duplicates counted); x itself need not be present.
int query(NODE *o,int x){
	int smaller=0;
	while(o){
		const int side=o->cmp(x);
		const int left=o->ch[0]?o->ch[0]->s:0;
		if(side==-1)return smaller+left;
		if(side==1)smaller+=left+o->w;	// this node and its left subtree are all < x
		o=o->ch[side];
	}
	return smaller;
}
/*
//求前驱:
int sz=query(root,x);	
printf("%d\n",kth(root,sz));

//求后继:
int sz=query(root,x);
sz+=find(root,x)+1;
printf("%d\n",kth(root,sz));

int main(){
	insert(root,1);
	insert(root,2);
	insert(root,3);
	cout<<query(root,3)<<endl;
}
*/

// Driver: reads n, then n pairs "op x" from stdin.
//   1 insert x          2 delete one x
//   3 print rank of x (number of smaller elements + 1)
//   4 print the x-th smallest element
//   5 print the predecessor of x     6 print the successor of x
int main(){
	int n;
	scanf("%d",&n);
	for(int op,x;n--;){
		scanf("%d%d",&op,&x);
		switch(op){
		case 1:
			insert(root,x);
			break;
		case 2:
			del(root,x);
			break;
		case 3:
			printf("%d\n",query(root,x)+1);
			break;
		case 4:
			printf("%d\n",kth(root,x));
			break;
		case 5:{
			// The predecessor is the element ranked "count of smaller elements".
			int sz=query(root,x);
			printf("%d\n",kth(root,sz));
			break;
		}
		case 6:{
			// Skip everything smaller than x and all copies of x itself.
			int sz=query(root,x);
			sz+=find(root,x)+1;
			printf("%d\n",kth(root,sz));
			break;
		}
		}
	}
}

哪天有空细化一下分类吧。。

--------------------------------------------------------------

下面是把 rand() 换成自写随机数生成器(Random)的版本,其余逻辑与上面相同:

#include<bits/stdc++.h>
#include<stdio.h>
#include<algorithm>
#include<queue>
#include<string.h>
#include<iostream>
#include<math.h>
#include<set>
#include<map>
#include<vector>
#include<iomanip>
using namespace std;
// Type shorthands and helper macros.
#define ll long long
#define ull unsigned long long
#define pb push_back
// FOR(a): loop with its own `int i` running from 1 to a inclusive.
#define FOR(a) for(int i=1;i<=a;i++)
// 0x3f3f3f3f == 1061109567, a large int sentinel.
const int inf=0x3f3f3f3f;
// maxn and mod are not referenced by the pointer-based treap visible below.
const int maxn=2e6+7; 
const long long mod=1e9+7;

// Deterministic pseudo-random generator: each call replaces the internal
// state with state * 48271 mod 2147483647 and returns it. The state starts
// at 703, so the sequence is the same on every run.
inline int Random(){
	static int seed=703;
	const long long next=48271ll*seed%2147483647;	// 64-bit product avoids int overflow
	seed=int(next);
	return seed;
}
// One treap node. Equal keys share a node and are counted in w.
struct NODE{
	NODE *ch[2];	// ch[0] = left child, ch[1] = right child
	int v;			// key stored in this node
	int r;			// priority, compared by operator<
	int s;			// number of elements in this subtree, duplicates included
	int w;			// how many times v occurs
	// Creates a leaf holding one copy of v with a priority taken from Random().
	NODE(int v):v(v),r(Random()),s(1),w(1){
		ch[0]=NULL;
		ch[1]=NULL;
	}
	// Orders nodes by priority only.
	bool operator < (const NODE &rhs)const{
		return r<rhs.r;
	}
	// Key comparison: -1 if x is this node's key, 0 if x belongs to the
	// left subtree, 1 if it belongs to the right subtree.
	inline int cmp(int x)const{
		if(x<v)return 0;
		if(x>v)return 1;
		return -1;
	}
	// Rank comparison for a 1-based rank x inside this subtree: -1 if the
	// x-th element is one of this node's copies, 0 if it lies in the left
	// subtree, 1 if it lies in the right subtree.
	inline int cmp1(int x)const{
		const int below=ch[0]?ch[0]->s:0;	// elements strictly before this node
		if(x<=below)return 0;
		if(x<=below+w)return -1;
		return 1;
	}
	// Recomputes s from w and the children's sizes.
	inline void maintain(){
		int total=w;
		for(int side=0;side<2;++side){
			if(ch[side])total+=ch[side]->s;
		}
		s=total;
	}
}*root;
// Rotates the subtree rooted at o. d==0 is a left rotation (the right child
// becomes the new root), d==1 is a right rotation (the left child rises).
// The child on side d^1 must exist. Sizes of both touched nodes are
// recomputed, lower node first.
inline void rotate(NODE* &o,int d){
	const int up=d^1;			// side of the child that moves up
	NODE *lower=o;
	NODE *upper=lower->ch[up];
	lower->ch[up]=upper->ch[d];	// hand over the inner subtree
	upper->ch[d]=lower;
	lower->maintain();
	upper->maintain();
	o=upper;
}
// Inserts one copy of v into the treap rooted at o.
// An existing key only gets its multiplicity bumped; a new key is attached
// as a leaf and rotated upwards while its priority exceeds its parent's,
// keeping a max-heap on NODE::r.
inline void insert(NODE* &o,int v){
	if(!o){
		o=new NODE(v);	// empty slot: the new leaf already has s == w == 1
		return;
	}
	int d=o->cmp(v);
	if(d==-1){
		o->w++;
	}else{
		insert(o->ch[d],v);
		// Compare priorities, not pointer addresses: with address comparison
		// the generated priorities in NODE::r would never be used at all.
		if(o->r < o->ch[d]->r)rotate(o,d^1);
	}
	o->maintain();
}
// Removes one copy of v from the treap rooted at o.
// Does nothing when v is not stored. A node whose multiplicity drops to zero
// is rotated down (the child with the larger priority rises, preserving the
// max-heap on NODE::r) until it has at most one child, then unlinked and freed.
inline void del(NODE* &o,int v){
	if(!o)return;	// v is not in the tree: nothing to delete
	int d=o->cmp(v);
	if(d!=-1){
		del(o->ch[d],v);
	}else if(o->w>1){
		o->w--;
	}else if(o->ch[0]&&o->ch[1]){
		// Compare priorities, not pointer addresses.
		int d2=(o->ch[0]->r > o->ch[1]->r)?1:0;
		rotate(o,d2);
		del(o->ch[d2],v);	// the node to remove now hangs on side d2
	}else{
		NODE *dead=o;
		o=o->ch[0]?o->ch[0]:o->ch[1];
		delete dead;	// release the unlinked node instead of leaking it
	}
	if(o)o->maintain();
}
// Frees every node of the subtree rooted at o (children first) and resets o
// to NULL. Safe to call on an empty subtree.
inline void remove(NODE* &o){
	if(o==NULL)return;
	for(int side=0;side<2;++side){
		if(o->ch[side]!=NULL)remove(o->ch[side]);
	}
	delete o;
	o=NULL;
}
// Returns how many copies of x are stored in the subtree rooted at o
// (0 when x is absent). The tree is not modified.
inline int find(NODE* &o,int x){
	NODE *cur=o;
	while(cur!=NULL){
		const int side=cur->cmp(x);
		if(side==-1)return cur->w;
		cur=cur->ch[side];
	}
	return 0;
}
// Returns the k-th smallest element (1-based, duplicates counted) of the
// subtree rooted at o.
// Returns 0 when k is outside 1..size instead of dereferencing a null child;
// this is reached e.g. when a predecessor/successor query has no answer.
inline int kth(NODE* o,int k){
	while(o){
		int d=o->cmp1(k);
		if(d==-1)return o->v;
		if(d==1){
			// Skip this node's copies and its whole left subtree.
			k-=o->w;
			if(o->ch[0])k-=o->ch[0]->s;
		}
		o=o->ch[d];
	}
	return 0;
}
// Counts the stored elements strictly smaller than x in the subtree rooted
// at o (duplicates counted); x itself need not be present.
inline int query(NODE *o,int x){
	int smaller=0;
	while(o){
		const int side=o->cmp(x);
		const int left=o->ch[0]?o->ch[0]->s:0;
		if(side==-1)return smaller+left;
		if(side==1)smaller+=left+o->w;	// this node and its left subtree are all < x
		o=o->ch[side];
	}
	return smaller;
}
/*
//求前驱:
int sz=query(root,x);	
printf("%d\n",kth(root,sz));

//求后继:
int sz=query(root,x);
sz+=find(root,x)+1;
printf("%d\n",kth(root,sz));

int main(){
	insert(root,1);
	insert(root,2);
	insert(root,3);
	cout<<query(root,3)<<endl;
}
*/

// Driver: reads n, then n pairs "op x" from stdin.
//   1 insert x          2 delete one x
//   3 print rank of x (number of smaller elements + 1)
//   4 print the x-th smallest element
//   5 print the predecessor of x     6 print the successor of x
int main(){
	int n;
	scanf("%d",&n);
	for(int op,x;n--;){
		scanf("%d%d",&op,&x);
		switch(op){
		case 1:
			insert(root,x);
			break;
		case 2:
			del(root,x);
			break;
		case 3:
			printf("%d\n",query(root,x)+1);
			break;
		case 4:
			printf("%d\n",kth(root,x));
			break;
		case 5:{
			// The predecessor is the element ranked "count of smaller elements".
			int sz=query(root,x);
			printf("%d\n",kth(root,sz));
			break;
		}
		case 6:{
			// Skip everything smaller than x and all copies of x itself.
			int sz=query(root,x);
			sz+=find(root,x)+1;
			printf("%d\n",kth(root,sz));
			break;
		}
		}
	}
}

-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

数组版,写起来舒服多了;但节点数受 maxn 限制,如果被删除节点的下标不回收,反复插入、删除后可能超出 maxn 导致数组越界。

#include<bits/stdc++.h>
#include<stdio.h>
#include<algorithm>
#include<queue>
#include<string.h>
#include<iostream>
#include<math.h>
#include<set>
#include<map>
#include<vector>
#include<iomanip>
using namespace std;
// Type shorthands and helper macros.
#define ll long long
#define ull unsigned long long
#define pb push_back
// FOR(a): loop with its own `int i` running from 1 to a inclusive.
#define FOR(a) for(int i=1;i<=a;i++)
// 0x3f3f3f3f == 1061109567, a large int sentinel.
const int inf=0x3f3f3f3f;
const int maxn=1e5+7;	// capacity of the node arrays; at most maxn-1 live nodes
const long long mod=1e9+7;

// Array-based treap: a binary search tree on val[] that is a min-heap on rnd[].
// Node ids are indices into the parallel arrays below. Id 0 is the null node;
// it is never written, so siz[0] stays 0 from static zero-initialisation.
// The subtree-size array is called siz rather than size: under C++17, a
// global named size is ambiguous with std::size once <bits/stdc++.h> and
// `using namespace std` are in effect.
int nodecnt,root;	// highest id handed out so far, id of the tree root
int val[maxn],happen[maxn],l[maxn],r[maxn],rnd[maxn],siz[maxn];
// val: key, happen: multiplicity, l/r: child ids, rnd: priority,
// siz: elements in the subtree (duplicates counted).
int freed[maxn],freecnt;	// stack of ids released by del, reused by newnode

// Returns the id of a fresh single-element leaf holding v. Released ids are
// reused first, so repeated insert/delete cycles do not exhaust the arrays.
static int newnode(int v){
	int x=freecnt?freed[--freecnt]:++nodecnt;
	l[x]=r[x]=0;
	val[x]=v;
	happen[x]=siz[x]=1;
	rnd[x]=rand();
	return x;
}

// Recomputes siz[x] from the children and the node's own multiplicity.
void maintain(int x){siz[x]=siz[l[x]]+siz[r[x]]+happen[x];}
// Right rotation: the left child of k becomes the root of this subtree.
void rturn(int &k){int t=l[k];l[k]=r[t];r[t]=k;maintain(k);maintain(t);k=t;}
// Left rotation: the right child of k becomes the root of this subtree.
void lturn(int &k){int t=r[k];r[k]=l[t];l[t]=k;maintain(k);maintain(t);k=t;}

// Inserts one copy of v by key order into the subtree whose root id is x.
void insert1(int &x,int v){
	if(!x){
		x=newnode(v);
		return;
	}
	siz[x]++;	// the insertion always succeeds, so every node on the path grows
	if(val[x]==v)happen[x]++;
	else if(v<val[x]){
		insert1(l[x],v);
		if(rnd[l[x]]<rnd[x])rturn(x);	// restore the min-heap on rnd
	}else{
		insert1(r[x],v);
		if(rnd[r[x]]<rnd[x])lturn(x);
	}
}

// Inserts v by position: pos is the number of elements that end up before
// the new one (0 puts it first). Every node is treated as one element.
// NOTE(review): presumably meant for a separate sequence-style tree, not to
// be mixed with insert1 on the same root - confirm with callers.
void insert2(int &x,int v,int pos){
	if(!x){
		x=newnode(v);
		return;
	}
	siz[x]++;
	if(siz[l[x]]<pos){
		insert2(r[x],v,pos-siz[l[x]]-1);
		if(rnd[r[x]]<rnd[x])lturn(x);
	}else{
		insert2(l[x],v,pos);
		if(rnd[l[x]]<rnd[x])rturn(x);
	}
}

// Removes one copy of v from the subtree whose root id is x.
// Does nothing when v is absent: sizes on the search path are recomputed
// bottom-up with maintain() instead of being decremented on the way down,
// so a miss leaves every size untouched.
void del(int &x,int v){
	if(!x)return;
	if(val[x]==v){
		if(happen[x]>1){happen[x]--;siz[x]--;return;}
		if(!l[x]||!r[x]){
			int dead=x;
			x=l[x]+r[x];			// at most one child exists; 0 if none
			freed[freecnt++]=dead;	// make the id available again
			return;
		}
		// Two children: rotate the smaller-priority child up, then continue
		// on the new subtree root, which holds a different key.
		if(rnd[l[x]]<rnd[r[x]])rturn(x);
		else lturn(x);
		del(x,v);
		return;
	}
	if(v<val[x])del(l[x],v);
	else del(r[x],v);
	maintain(x);
}

// Rank of v in the subtree x: the number of elements smaller than v, plus 1
// when v itself is stored (an absent v yields just the count of smaller ones).
int rnk(int x,int v){
	if(!x)return 0;
	if(val[x]==v)return siz[l[x]]+1;
	else if(v<val[x])return rnk(l[x],v);
	else return siz[l[x]]+happen[x]+rnk(r[x],v);
}

// k-th smallest element (1-based, duplicates counted) of the subtree x;
// returns 0 when the search runs off the tree.
int kth(int x,int k){
	if(!x)return 0;
	if(k<=siz[l[x]])return kth(l[x],k);
	else if(k>siz[l[x]]+happen[x])
		return kth(r[x],k-siz[l[x]]-happen[x]);
	else return val[x];
}

int ans;	// result slot for pre()/suf(): a node id; the caller resets it to 0 first

// Leaves in ans the id of the node with the largest key strictly less than v
// (ans is not modified when no such node exists).
void pre(int x,int v){
	if(!x)return;
	if(v>val[x])ans=x,pre(r[x],v);
	else pre(l[x],v);
}

// Leaves in ans the id of the node with the smallest key strictly greater
// than v (ans is not modified when no such node exists).
void suf(int x,int v){
	if(!x)return;
	if(v<val[x])ans=x,suf(l[x],v);
	else suf(r[x],v);
}

int main(){
	int n;scanf("%d",&n);
	int op,num;
	FOR(n){
		scanf("%d%d",&op,&num);
		if(op==1){
			insert1(root,num);
		}else if(op==2){
			del(root,num);
		}else if(op==3){
			printf("%d\n",rnk(root,num));
		}else if(op==4){
			printf("%d\n",kth(root,num));
		}else if(op==5){
			ans=0;
			pre(root,num);
			printf("%d\n",val[ans]);
		}else{
			ans=0;
			suf(root,num);
			printf("%d\n",val[ans]);
		}
	}	
}


posted @ 2017-08-28 23:51  Drenight  阅读(152)  评论(0编辑  收藏  举报