P3369 【模板】普通平衡树(Treap/SBT)
题目描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
-
插入x数
-
删除x数(若有多个相同的数,因只删除一个)
-
查询x数的排名(若有多个相同的数,因输出最小的排名)
-
查询排名为x的数
-
求x的前驱(前驱定义为小于x,且最大的数)
- 求x的后继(后继定义为大于x,且最小的数)
输入输出格式
输入格式:
第一行为n,表示操作的个数,下面n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)
输出格式:
对于操作3,4,5,6每行输出一个数,表示对应答案
输入输出样例
输入样例#1:
10 1 106465 4 1 1 317721 1 460929 1 644985 1 84185 1 89851 6 81968 1 492737 5 493598
输出样例#1:
106465 84185 492737
说明
时空限制:1000ms,128M
1.n的数据范围:n<=100000
2.每个数的数据范围:[-1e7,1e7]
来源:Tyvj1728 原名:普通平衡树
在此鸣谢
1.03
treap是一棵修改了结点顺序的二叉查找树
通常树内的每个结点x都有一个关键字值key[x],另外,还要为结点分配一个随机数。
假设所有的优先级是不同的,所有的关键字也是不同的。treap的结点排列成让关键字遵循二叉查找树性质,并且优先级遵
循最小堆顺序性质:
1.如果v是u的左孩子,则key[v] < key[u].
2.如果v是u的右孩子,则key[v] > key[u].
3.如果v是u的孩子,则rand[v] > rand[u].
这两个性质的结合就是为什么这种树被称为“treap”的原因,因为它同时具有二叉查找树和heap的特征。
(1.18整编转载自hzwer)
#include<cstdio> #include<cstdlib> #include<ctime> using namespace std; const int N=1e6+10; struct tree{ int l,r;//左右儿子节点编号 int num;//当前节点的数字 int s;//以当前节点为根的子树的节点数 int sum;//当前节点的数字的数量 int rnd;//随机优先级 }tr[N];//下标为节点编号 int n,rt,cnt,t1,t2; void updata(int &k){ int &l=tr[k].l,&r=tr[k].r; tr[k].s=tr[l].s+tr[r].s+tr[k].sum; } void lturn(int &k){ int t=tr[k].r;tr[k].r=tr[t].l;tr[t].l=k; tr[t].s=tr[k].s;updata(k);k=t; } void rturn(int &k){ int t=tr[k].l;tr[k].l=tr[t].r;tr[t].r=k; tr[t].s=tr[k].s;updata(k);k=t; } void insert(int &k,int x){ if(!k){ k=++cnt;tr[k].num=x;tr[k].s=1;tr[k].sum++;tr[k].rnd=rand();return ; } tr[k].s++; int &l=tr[k].l,&r=tr[k].r; if(x<tr[k].num){ insert(l,x); if(tr[l].rnd<tr[k].rnd) rturn(k); } else if(x>tr[k].num){ insert(r,x); if(tr[r].rnd<tr[k].rnd) lturn(k); } else{ tr[k].sum++;return ; } } void del(int &k,int x){ if(!k) return ; int &l=tr[k].l,&r=tr[k].r; if(x==tr[k].num){ if(tr[k].sum>1){ tr[k].sum--;tr[k].s--;return ; } if(l*r==0) k=l+r; else{ if(tr[l].rnd<tr[r].rnd) rturn(k); else lturn(k); del(k,x); } } else{ tr[k].s--; if(x>tr[k].num) del(r,x); else del(l,x); } } int find1(int &k,int x){ if(!k) return 0; int &l=tr[k].l,&r=tr[k].r; if(tr[k].num==x) return tr[l].s+1; if(tr[k].num>x) return find1(l,x); if(tr[k].num<x) return tr[l].s+tr[k].sum+find1(r,x); } int find2(int &k,int x){ if(!k) return 0; int &l=tr[k].l,&r=tr[k].r; if(tr[l].s+1<=x&&tr[l].s+tr[k].sum>=x) return tr[k].num; if(tr[l].s>=x) return find2(l,x); if(tr[l].s+tr[k].sum<x) return find2(r,x-tr[l].s-tr[k].sum); } void pred(int &k,int x){ if(!k) return ; int &l=tr[k].l,&r=tr[k].r; if(tr[k].num<x){ t1=tr[k].num; pred(r,x); } else pred(l,x); } void succ(int &k,int x){ if(!k) return ; int &l=tr[k].l,&r=tr[k].r; if(tr[k].num>x){ t2=tr[k].num; succ(l,x); } else succ(r,x); } int main(){ srand(time(0)); scanf("%d",&n); for(int i=1,opt,x;i<=n;i++){ scanf("%d%d",&opt,&x);t1=t2=0; switch(opt){ case 1:insert(rt,x);break; case 2:del(rt,x);break; case 3:printf("%d\n",find1(rt,x));break; case 4:printf("%d\n",find2(rt,x));break; case 5:pred(rt,x);printf("%d\n",t1);break; case 6:succ(rt,x);printf("%d\n",t2);break; } } return 0; }
1.10
返现GNU系统有个pb_ds库,里面有好多bbt,我用rb-tree过掉了。
不过他的bbt不支持重复元素的出现,难道还要hash一下?
那不就失去了他的优越性了?
山人自有妙计。
//by shenben #include<cstdio> #include<iostream> #include<ext/pb_ds/assoc_container.hpp> #include<ext/pb_ds/tree_policy.hpp> using namespace std; using namespace __gnu_pbds; typedef long long ll; tree<ll,null_mapped_type,less<ll>,rb_tree_tag,tree_order_statistics_node_update> bbt; int n;ll k,ans; inline int read(){ int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } int main(){ freopen("phs.in","r",stdin); freopen("phs.out","w",stdout); n=read(); for(int i=1,opt;i<=n;i++){ opt=read();k=read(); if(opt==1) bbt.insert((k<<20)+i); if(opt==2) bbt.erase(bbt.lower_bound(k<<20)); if(opt==3) printf("%d\n",bbt.order_of_key(k<<20)+1); if(opt==4) ans=*bbt.find_by_order(k-1),printf("%lld\n",ans>>20); if(opt==5) ans=*--bbt.lower_bound(k<<20),printf("%lld\n",ans>>20); // if(opt==6) ans=*bbt.lower_bound(k+1<<20),printf("%lld\n",ans>>20); if(opt==6) ans=*bbt.upper_bound((k<<20)+n),printf("%lld\n",ans>>20); } return 0; }
1.18
看到hzwer‘blog暴力vector,于是就写过了。
貌似插入是O(√n+1)的
#include<cstdio> #include<vector> #include<algorithm> using namespace std; int read(){ int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } int n; vector<int>a; void insert(int x){ a.insert(upper_bound(a.begin(),a.end(),x),x); } void del(int x){ a.erase(lower_bound(a.begin(),a.end(),x)); } int find(int x){ return lower_bound(a.begin(),a.end(),x)-a.begin()+1; } int main(){ a.reserve(200000); n=read(); for(int i=1,opt,x;i<=n;i++){ opt=read();x=read(); switch(opt){ case 1:insert(x);break; case 2:del(x);break; case 3:printf("%d\n",find(x));break; case 4:printf("%d\n",a[x-1]);break; case 5:printf("%d\n",*--lower_bound(a.begin(),a.end(),x));break; case 6:printf("%d\n",*upper_bound(a.begin(),a.end(),x));break; } } return 0; }
2.26 splay版
#include<cstdio> using namespace std; const int N=1e5+5; const int inf=2e9; int n,c[N][2],fa[N],val[N],cnt[N],siz[N],rt,sz; inline int read(){ int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();} return x*f; } void updata(int x){ siz[x]=siz[c[x][0]]+siz[c[x][1]]+cnt[x]; } void rotate(int x,int &k){ int y=fa[x],z=fa[y],l,r; l=(c[y][1]==x);r=l^1; if(y==k) k=x; else c[z][c[z][1]==y]=x; fa[x]=z;fa[y]=x;fa[c[x][r]]=y; c[y][l]=c[x][r];c[x][r]=y; updata(y);updata(x); } void splay(int x,int &k){ while(x!=k){ int y=fa[x],z=fa[y]; if(y!=k){ if((c[y][0]==x)^(c[z][0]==y)) rotate(x,k); else rotate(y,k); } rotate(x,k); } } #define l c[k][0] #define r c[k][1] void Rank(int v){ int k=rt;if(!rt) return ; while(c[k][v>val[k]]&&val[k]!=v) k=c[k][v>val[k]]; splay(k,rt); } int kth(int rk){ rk++;int k=rt; if(siz[k]<rk) return 0; for(;;){ if(siz[l]<rk&&siz[l]+cnt[k]>=rk) return val[k]; if(siz[l]>=rk) k=l; else rk-=siz[l]+cnt[k],k=r; } } void insert(int v){ int k=rt,y=0; while(k&&val[k]!=v) y=k,k=c[k][v>val[k]]; if(k) cnt[k]++; else{ k=++sz;val[k]=v;siz[k]=cnt[k]=1;fa[k]=y; if(y) c[y][v>val[y]]=k; } splay(k,rt); } void erase(int v){ Rank(v);int k; if(cnt[rt]>1){cnt[rt]--;siz[rt]--;return ;} if(!c[rt][0]||!c[rt][1]){ rt=c[rt][0]+c[rt][1]; fa[rt]=0; } else{ k=c[rt][1]; while(l) k=l; siz[k]+=siz[c[rt][0]]; fa[c[rt][0]]=k;l=c[rt][0]; rt=c[rt][1]; fa[rt]=0; splay(k,rt); } } int prev(int v){ Rank(v); if(val[rt]<v) return val[rt]; int k=c[rt][0]; while(r) k=r; return val[k]; } int succ(int v){ Rank(v); if(val[rt]>v) return val[rt]; int k=c[rt][1]; while(l) k=l; return val[k]; } #undef l #undef r int main(){ insert(-inf);insert(inf); n=read(); for(int i=1,opt,x;i<=n;i++){ opt=read();x=read(); if(opt==1) insert(x); if(opt==2) erase(x); if(opt==3) Rank(x),printf("%d\n",siz[c[rt][0]]); if(opt==4) printf("%d\n",kth(x)); if(opt==5) printf("%d\n",prev(x)); if(opt==6) printf("%d\n",succ(x)); } return 0; }