bzoj3224
时间比较:SBT:无旋treap:splay:treap(以下各程序头部注释中记录的 BZOJ 用时:treap 300 ms,SBT 352 ms,无旋treap 432 ms,替罪羊树 260 ms;splay 未记录)
这道题是照着别人的模板打的,思路不难,但实现过程中有很多细节需要注意。
1. 首先是 treap
/**************************************************************
    Problem: 3224
    Language: C++
    Result: Accepted
    Time:300 ms
    Memory:3164 kb
****************************************************************/
#include<cstdio>
#include<cctype>
#include<algorithm>
using namespace std;

// Rotating treap: a BST on `val` that is also a min-heap on the random
// priority `rnd`. Equal values share one node (multiplicity `cnt`); `siz` is
// the total multiplicity stored in the subtree. Node 0 is the null sentinel.
// The node counter is `tot` (not `size`) and the generator is `treap_rand`
// (not `rand`): under `using namespace std;` the old names collide with
// std::size (C++17) and with ::rand from <cstdlib>.
int n,rt,tot;   // n: operation count; rt: root (0 = empty); tot: nodes allocated
struct Node{
    int l,r,val,cnt,siz,rnd;   // children, value, multiplicity, subtree size, priority
}tr[100005];

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if EOF is reached before any digit. `ch` is an int so EOF is
// distinguishable and isdigit() never receives a negative char.
inline int read(){
    int ch=getchar(),k=0,f=1;
    while(!isdigit(ch)){
        if(ch==EOF) return 0;
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch)){ k=k*10+(ch-'0'); ch=getchar(); }
    return k*f;
}

// Recompute tr[root].siz from its children and its own multiplicity.
inline void update(int root){
    tr[root].siz=tr[tr[root].l].siz+tr[tr[root].r].siz+tr[root].cnt;
}

// Deterministic pseudo-random priority in [0, 1000000007).
inline int treap_rand(){
    static int seed=2333;
    return seed=(int)((((seed^998244353)+19260817ll)*19890604ll)%1000000007);
}

// Left rotation: the right child of `root` becomes the subtree root.
void lturn(int &root){
    int t=tr[root].r;
    tr[root].r=tr[t].l;
    tr[t].l=root;
    tr[t].siz=tr[root].siz;
    update(root);
    root=t;
}

// Right rotation: the left child of `root` becomes the subtree root.
void rturn(int &root){
    int t=tr[root].l;
    tr[root].l=tr[t].r;
    tr[t].r=root;
    tr[t].siz=tr[root].siz;
    update(root);
    root=t;
}

// Insert one copy of `now`; a child whose rnd is smaller than its parent's
// is rotated up to keep the min-heap order.
void _insert(int &root,int now){
    if(root==0){
        root=++tot;
        tr[root].siz=tr[root].cnt=1;
        tr[root].val=now;
        tr[root].rnd=treap_rand();
        return;
    }
    tr[root].siz++;
    if(tr[root].val==now) tr[root].cnt++;
    else if(tr[root].val>now){
        _insert(tr[root].l,now);
        if(tr[root].rnd>tr[tr[root].l].rnd) rturn(root);
    }
    else{
        _insert(tr[root].r,now);
        if(tr[root].rnd>tr[tr[root].r].rnd) lturn(root);
    }
}

// Remove one copy of `now` from the subtree; returns whether a copy was found.
// Sizes on the path are decremented only after a successful removal, so
// deleting an absent value leaves the tree untouched.
static bool treap_erase(int &root,int now){
    if(root==0) return false;
    if(tr[root].val==now){
        if(tr[root].cnt>1){ tr[root].cnt--; tr[root].siz--; return true; }
        if(tr[root].l==0||tr[root].r==0){ root=tr[root].l+tr[root].r; return true; }
        // Two children: rotate the child with the smaller rnd up, then retry below it.
        if(tr[tr[root].r].rnd>tr[tr[root].l].rnd) rturn(root); else lturn(root);
        return treap_erase(root,now);
    }
    bool removed=treap_erase(tr[root].val>now?tr[root].l:tr[root].r,now);
    if(removed) tr[root].siz--;
    return removed;
}

// Remove one copy of `now` if present (no-op otherwise).
// NOTE(review): leading-underscore global names are reserved in C++; kept
// because the driver calls `_insert`/`_del`.
void _del(int &root,int now){
    treap_erase(root,now);
}

// If `now` is stored, returns its rank: 1 + number of stored values smaller
// than it (counting multiplicity). If `now` is absent, returns only the count
// of smaller values (no +1).
int query_xpm(int root,int now){
    if(root==0) return 0;
    if(tr[root].val==now) return tr[tr[root].l].siz+1;
    if(tr[root].val<now) return tr[tr[root].l].siz+tr[root].cnt+query_xpm(tr[root].r,now);
    return query_xpm(tr[root].l,now);
}
int query_pmx(int root,int now){ if(root==0) return 0; if(tr[tr[root].l].siz<now && tr[root].cnt+tr[tr[root].l].siz>=now) return tr[root].val; else if(now<=tr[tr[root].l].siz) return query_pmx(tr[root].l,now); else return query_pmx(tr[root].r,now-tr[tr[root].l].siz-tr[root].cnt); } int query_qq(int root,int now){ if(root==0) return -2e9; if(tr[root].val>=now) return query_qq(tr[root].l,now); else return max(tr[root].val,query_qq(tr[root].r,now)); } int query_hj(int root,int now){ if(root==0) return 2e9; if(tr[root].val<=now) return query_hj(tr[root].r,now); else return min(tr[root].val,query_hj(tr[root].l,now)); } int main(){ n=read(); int flag,x; for(int i=1;i<=n;i++){ flag=read();x=read(); if(flag==1) _insert(rt,x); if(flag==2) _del(rt,x); if(flag==3) printf("%ld\n",query_xpm(rt,x)); if(flag==4) printf("%ld\n",query_pmx(rt,x)); if(flag==5) printf("%ld\n",query_qq(rt,x)); if(flag==6) printf("%ld\n",query_hj(rt,x)); } return 0; }
2. 然后是 splay
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
#define maxn 1000005

// Array-based splay tree. Node 0 is the null sentinel.
// tr[x][0]/tr[x][1]: left/right child; f[x]: parent; key[x]: stored value;
// cnt[x]: multiplicity of key[x]; siz[x]: total multiplicity in x's subtree.
int tr[maxn][2],f[maxn],siz[maxn],cnt[maxn],key[maxn];
int sz,rt;   // sz: last allocated node index; rt: current root (0 = empty)

// Reset every field of node x.
inline void clear(int x){ tr[x][0]=tr[x][1]=f[x]=siz[x]=cnt[x]=key[x]=0; }

// Recompute siz[x] from its own count and its children's sizes.
inline void updata(int x){ siz[x]=cnt[x]+siz[tr[x][0]]+siz[tr[x][1]]; }

// Rotate x above its parent, preserving in-order order; refreshes both sizes.
inline void rotate(int x){
    int old=f[x],oldf=f[old],whicx=(tr[f[x]][1]==x);
    tr[old][whicx]=tr[x][whicx^1];
    tr[x][whicx^1]=old;
    f[tr[old][whicx]]=old;
    f[old]=x;
    f[x]=oldf;
    if(oldf) tr[oldf][tr[oldf][1]==old]=x;
    updata(old);
    updata(x);
}

// Splay x to the root (zig-zig rotates the parent first, zig-zag rotates x twice).
inline void splay(int x){
    for(int fa;(fa=f[x]);rotate(x))
        if(f[fa]) rotate((tr[f[x]][1]==x)==(tr[f[fa]][1]==fa)?fa:x);
    rt=x;
}

// Insert one copy of x and splay its node to the root.
inline void insert(int x){
    if(rt==0){
        sz++;
        tr[sz][0]=tr[sz][1]=f[sz]=0;
        rt=sz;
        siz[sz]=cnt[sz]=1;
        key[sz]=x;
        return;
    }
    int now=rt,fa=0;
    while(1){
        if(x==key[now]){
            cnt[now]++;
            updata(now);
            if(fa) updata(fa);
            splay(now);
            break;
        }
        fa=now;
        now=tr[now][key[now]<x];
        if(now==0){
            sz++;
            tr[sz][0]=tr[sz][1]=0;
            f[sz]=fa;
            key[sz]=x;
            siz[sz]=cnt[sz]=1;
            tr[fa][key[fa]<x]=sz;
            updata(fa);
            splay(sz);
            break;
        }
    }
}

// Rank of x: 1 + number of stored values smaller than x. If x is present its
// node is splayed to the root; if absent nothing is splayed (previously this
// looped forever).
inline int find(int x){
    int now=rt,ans=0;
    while(now){
        if(x<key[now]) now=tr[now][0];
        else{
            ans+=(tr[now][0]?siz[tr[now][0]]:0);
            if(x==key[now]){ splay(now); return ans+1; }
            ans+=cnt[now];
            now=tr[now][1];
        }
    }
    return ans+1;
}

// k-th smallest value (1-based, counting multiplicity); 0 if x is out of range.
inline int findx(int x){
    int now=rt;
    while(now){
        if(tr[now][0]&&x<=siz[tr[now][0]]) now=tr[now][0];
        else{
            int tmp=(tr[now][0]?siz[tr[now][0]]:0)+cnt[now];
            if(x<=tmp) return key[now];
            x-=tmp;
            now=tr[now][1];
        }
    }
    return 0;
}

// In-order predecessor / successor node of the root (0 if none).
inline int pre(){ int now=tr[rt][0]; while(tr[now][1]) now=tr[now][1]; return now; }
inline int next(){ int now=tr[rt][1]; while(tr[now][0]) now=tr[now][0]; return now; }

// Remove one copy of x; no-op if x is not stored.
inline void del(int x){
    find(x);                            // splays x to the root when present
    if(rt==0||key[rt]!=x) return;       // absent: nothing to delete
    if(cnt[rt]>1){ cnt[rt]--; updata(rt); return; }
    if(!tr[rt][0]&&!tr[rt][1]){ clear(rt); rt=0; return; }
    if(!tr[rt][0]){ int oldrt=rt; rt=tr[rt][1]; f[rt]=0; clear(oldrt); return; }
    else if(!tr[rt][1]){ int oldrt=rt; rt=tr[rt][0]; f[rt]=0; clear(oldrt); return; }
    // Two children: splay the predecessor to the root; the old root is then its
    // right child with no left child, so splice in the old root's right subtree.
    int lefb=pre(),oldrt=rt;
    splay(lefb);
    tr[rt][1]=tr[oldrt][1];
    f[tr[oldrt][1]]=rt;
    clear(oldrt);
    updata(rt);
}

// Reads n operations "opt x"; ops 5/6 temporarily insert x so its neighbour
// becomes the root's predecessor/successor, then delete it again.
// NOTE(review): if no predecessor/successor exists, key[0] (0) is printed.
int main(){
    int n,opt,x;
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d%d",&opt,&x);
        switch(opt){
            case 1:insert(x);break;
            case 2:del(x);break;
            case 3:printf("%d\n",find(x));break;
            case 4:printf("%d\n",findx(x));break;
            case 5:insert(x);printf("%d\n",key[pre()]);del(x);break;
            case 6:insert(x);printf("%d\n",key[next()]);del(x);break;
        }
    }
    return 0;
}
3.无旋treap
/**************************************************************
    Problem: 3224
    Language: C++
    Result: Accepted
    Time:432 ms
    Memory:2780 kb
****************************************************************/
#include<cstdio>
#include<cctype>
#include<cstdlib>
#include<algorithm>
#include<utility>
using namespace std;

// Non-rotating (FHQ) treap: split/merge by subtree size, min-heap on `key`.
// Every inserted value gets its own node; `data` is the stored value and `siz`
// the number of nodes in the subtree. Node 0 is the null sentinel.
// Pairs are built with par(a,b): the old `make_pair<int,int>` macro does not
// compile in C++11+, where explicit template arguments turn the parameters
// into int&& that cannot bind lvalues.
int cnt,rt,n;                 // nodes allocated, root (0 = empty), op count
typedef pair<int,int> par;    // (left tree root, right tree root)
struct data{int l,r,key,data,siz;}tr[100002];

// Reads one (optionally negative) decimal integer from stdin into x; x is 0
// if EOF is reached before any digit. `ch` is an int so EOF is distinguishable
// and isdigit() never receives a negative char.
void read(int &x){
    int ch=getchar();x=0;int f=1;
    while(!isdigit(ch)){
        if(ch==EOF) return;
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch)){ x=x*10+(ch-'0'); ch=getchar(); }
    x*=f;
}

// Recompute the subtree node count of `now`.
void updata(int now){
    tr[now].siz=tr[tr[now].l].siz+tr[tr[now].r].siz+1;
}

// If x is stored, returns its smallest rank (1-based); otherwise returns the
// number of stored values smaller than x. Named get_rank because a global
// `rank` is ambiguous with std::rank under `using namespace std;`.
int get_rank(int x){
    int k=rt,sum=0,tmp=(int)1e9;
    while(k){
        if(tr[k].data==x) tmp=min(tmp,sum+tr[tr[k].l].siz+1);
        if(tr[k].data<x){ sum+=tr[tr[k].l].siz+1; k=tr[k].r; }
        else k=tr[k].l;
    }
    return tmp==(int)1e9?sum:tmp;
}

// True if some node stores x.
static bool contains(int x){
    int k=rt;
    while(k){
        if(tr[k].data==x) return true;
        k=(x<tr[k].data)?tr[k].l:tr[k].r;
    }
    return false;
}

// Split tree x into (its first k nodes in order, the rest).
// Requires 0 <= k <= tr[x].siz.
par split(int x,int k){
    if(k==0) return par(0,x);
    int ls=tr[x].l,rs=tr[x].r;
    if(k==tr[ls].siz){ tr[x].l=0; updata(x); return par(ls,x); }
    if(k==tr[ls].siz+1){ tr[x].r=0; updata(x); return par(x,rs); }
    if(k<tr[ls].siz){
        par tmp=split(ls,k);
        tr[x].l=tmp.second; updata(x);
        return par(tmp.first,x);
    }
    par tmp=split(rs,k-tr[ls].siz-1);
    tr[x].r=tmp.first; updata(x);
    return par(x,tmp.second);
}

// Merge trees a and b, where every node of a precedes every node of b in
// order; the root with the smaller key stays on top.
int merge(int a,int b){
    if(a==0||b==0) return a+b;
    if(tr[a].key<tr[b].key){ tr[a].r=merge(tr[a].r,b); updata(a); return a; }
    tr[b].l=merge(a,tr[b].l); updata(b); return b;
}

// Insert a new node holding x at its sorted position.
void insert(int x){
    int k=get_rank(x);
    par tmp=split(rt,k);
    tr[++cnt].data=x;
    tr[cnt].key=rand();
    tr[cnt].siz=1;
    rt=merge(tmp.first,cnt);
    rt=merge(rt,tmp.second);
}

// Remove one node holding x; no-op if x is absent (previously this removed a
// neighbouring value, or recursed forever when no smaller value existed).
void del(int x){
    if(!contains(x)) return;
    int k=get_rank(x);                 // x's smallest rank
    par tmp=split(rt,k);               // tmp.first ends with one copy of x
    par tmp2=split(tmp.first,k-1);     // tmp2.second is that single node
    rt=merge(tmp2.first,tmp.second);
}

// k-th smallest value (1-based) in tree x. Requires 1 <= k <= tr[x].siz.
int askrank(int x,int k){
    while(true){
        if(tr[tr[x].l].siz+1==k) return tr[x].data;
        if(tr[tr[x].l].siz<k){ k-=(tr[tr[x].l].siz+1); x=tr[x].r; }
        else x=tr[x].l;
    }
}

// Largest stored value strictly less than x, or -1e9 if none.
int pre(int now,int x){
    int ans=-(int)1e9;
    while(now){
        if(tr[now].data<x){ ans=max(ans,tr[now].data); now=tr[now].r; }
        else now=tr[now].l;
    }
    return ans;
}

// Smallest stored value strictly greater than x, or 2e9 if none.
int nex(int now,int x){
    int ans=(int)2e9;
    while(now){
        if(tr[now].data>x){ ans=min(ans,tr[now].data); now=tr[now].l; }
        else now=tr[now].r;
    }
    return ans;
}

// Reads n operations "opt x" and prints the answer to each query (opt 3..6).
int main(){
    read(n);
    for(int i=1;i<=n;i++){
        int opt,x;
        read(opt);read(x);
        switch(opt){
            case 1:insert(x);break;
            case 2:del(x);break;
            case 3:printf("%d\n",get_rank(x));break;
            case 4:printf("%d\n",askrank(rt,x));break;
            case 5:printf("%d\n",pre(rt,x));break;
            case 6:printf("%d\n",nex(rt,x));break;
        }
    }
    return 0;
}
4.SBT
/**************************************************************
    Problem: 3224
    Language: C++
    Result: Accepted
    Time:352 ms
    Memory:2400 kb
****************************************************************/
#include<cstdio>
#include<cctype>
#include<cstring>
#include<algorithm>
#define maxn 100001
using namespace std;

int n;   // number of operations

// Reads one (optionally negative) decimal integer from stdin into x; x is 0
// if EOF is reached before any digit. `ch` is an int so EOF is distinguishable
// and isdigit() never receives a negative char.
void read(int &x){
    int ch=getchar();x=0;int f=1;
    while(!isdigit(ch)){
        if(ch==EOF) return;
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch)){ x=x*10+(ch-'0'); ch=getchar(); }
    x*=f;
}

// Size Balanced Tree. Every inserted value gets its own node; siz[p] is the
// number of nodes in p's subtree. Node 0 is the null sentinel (siz[0] == 0).
struct SBT{
    int rt,cnt;                                  // root, number of allocated nodes
    int key[maxn],siz[maxn],ls[maxn],rs[maxn];   // value, subtree size, children

    // Empty the tree and zero every node.
    void clear(){
        rt=0;cnt=0;
        memset(key,0,sizeof(key));
        memset(siz,0,sizeof(siz));
        memset(ls,0,sizeof(ls));
        memset(rs,0,sizeof(rs));
    }
    // Left rotation: the right child k of p becomes the subtree root.
    void zig(int &p){
        int k=rs[p];
        rs[p]=ls[k];
        ls[k]=p;
        siz[k]=siz[p];
        siz[p]=siz[ls[p]]+siz[rs[p]]+1;
        p=k;
    }
    // Right rotation: the left child k of p becomes the subtree root.
    void zag(int &p){
        int k=ls[p];
        ls[p]=rs[k];
        rs[k]=p;
        siz[k]=siz[p];
        siz[p]=siz[ls[p]]+siz[rs[p]]+1;
        p=k;
    }
    // Restore the size-balance conditions at p. flag==false checks whether the
    // left side became too heavy, flag==true the right side; after any
    // rotation the affected subtrees are re-maintained recursively.
    void maintain(int &p,bool flag){
        if(!flag){
            if(siz[ls[ls[p]]]>siz[rs[p]]) zag(p);
            else{
                if(siz[rs[ls[p]]]>siz[rs[p]]){ zig(ls[p]); zag(p); }
                else return;
            }
        }else{
            if(siz[rs[rs[p]]]>siz[ls[p]]) zig(p);
            else{
                if(siz[ls[rs[p]]]>siz[ls[p]]){ zag(rs[p]); zig(p); }
                else return;
            }
        }
        maintain(ls[p],false);
        maintain(rs[p],true);
        maintain(p,true);
        maintain(p,false);
    }
    // Insert x as a new node (values >= key[p] go right), then rebalance.
    void insert(int &p,int x){
        if(!p){ p=++cnt; key[p]=x; siz[p]=1; return; }
        siz[p]++;
        if(x<key[p]) insert(ls[p],x); else insert(rs[p],x);
        maintain(p,x>=key[p]);
    }
    // Remove one node and return the value it held: the node storing x if the
    // search reaches one, otherwise the node where the search ends. A node with
    // two children takes its in-order predecessor's value instead.
    // Requires p != 0 (siz[p] is decremented unconditionally).
    int erase(int &p,int x){
        siz[p]--; int tmp;
        if(x==key[p]||(x<key[p]&&!ls[p])||(x>key[p]&&!rs[p])){
            tmp=key[p];
            if(!ls[p]||!rs[p]) p=ls[p]+rs[p];
            // key[p]+1 exceeds every key in the left subtree, so this removes
            // the left subtree's maximum (the predecessor).
            else key[p]=erase(ls[p],key[p]+1);
            return tmp;
        }
        if(x<key[p]) tmp=erase(ls[p],x); else tmp=erase(rs[p],x);
        return tmp;
    }
    // 1 + number of stored values strictly less than x.
    int rank(int &p,int x){
        if(!p) return 1;
        int tmp=0;
        if(x<=key[p]) tmp=rank(ls[p],x);
        else tmp=siz[ls[p]]+1+rank(rs[p],x);
        return tmp;
    }
    // x-th smallest value (1-based). Requires 1 <= x <= siz[p].
    int askrank(int &p,int x){
        if(x==siz[ls[p]]+1) return key[p];
        if(x<=siz[ls[p]]) return askrank(ls[p],x);
        else return askrank(rs[p],x-1-siz[ls[p]]);
    }
    // Largest stored value strictly less than x; returns x itself if none.
    int pre(int &p,int x){
        if(!p) return x;
        int tmp;
        if(x<=key[p]) tmp=pre(ls[p],x);
        else{ tmp=pre(rs[p],x); if(tmp==x) tmp=key[p]; }
        return tmp;
    }
    // Smallest stored value strictly greater than x; returns x itself if none.
    int nex(int &p,int x){
        if(!p) return x;
        int tmp;
        if(x>=key[p]) tmp=nex(rs[p],x);
        else{ tmp=nex(ls[p],x); if(tmp==x) tmp=key[p]; }
        return tmp;
    }
}T;

// Reads n operations "opt x" and prints the answer to each query (opt 3..6).
int main(){
    read(n);
    T.clear();
    int &rt=T.rt=0;   // alias for the tree root, updated in place by the operations
    while(n--){
        int opt,x;
        read(opt);read(x);
        switch(opt){
            case 1:T.insert(rt,x);break;
            case 2:T.erase(rt,x);break;
            case 3:printf("%d\n",T.rank(rt,x));break;
            case 4:printf("%d\n",T.askrank(rt,x));break;
            case 5:printf("%d\n",T.pre(rt,x));break;
            case 6:printf("%d\n",T.nex(rt,x));break;
        }
    }
}
5.替罪羊树
/**************************************************************
    Problem: 3224
    Language: C++
    Result: Accepted
    Time:260 ms
    Memory:5512 kb
****************************************************************/
#include<cctype>
#include<cstdio>
#include<algorithm>
#define maxn 200005
#define bz 0.75
using namespace std;

// Scapegoat tree with two sentinel nodes (values -2e9 and 2e9, set up in
// main). Every value has its own node; siz counts nodes in the subtree,
// sentinels included. cur[1..sum] is scratch space that recycle() fills with
// the nodes of a subtree in sorted order, for build() to use.
int n,cnt,rt,cur[maxn],sum;
struct data{int son[2],fa,siz,val;}tr[maxn];

// Reads one (optionally negative) decimal integer from stdin; returns 0 if
// EOF is reached before any digit. `ch` is an int so EOF is distinguishable
// and isdigit() never receives a negative char.
inline int read(){
    int x=0,f=1,ch=getchar();
    while(!isdigit(ch)){
        if(ch==EOF) return 0;
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(isdigit(ch)){ x=x*10+(ch-'0'); ch=getchar(); }
    return f*x;
}

// True when neither child of x holds more than a bz fraction of x's subtree.
inline bool balance(int x){
    return (double)tr[x].siz*bz>=(double)tr[tr[x].son[0]].siz
        && (double)tr[x].siz*bz>=(double)tr[tr[x].son[1]].siz;
}

// Append the nodes of subtree x to cur[] in in-order (sorted) order.
inline void recycle(int x){
    if(tr[x].son[0]) recycle(tr[x].son[0]);
    cur[++sum]=x;
    if(tr[x].son[1]) recycle(tr[x].son[1]);
}

// Build a balanced subtree from cur[l..r] (middle element as root); returns
// its root, or 0 for an empty range.
inline int build(int l,int r){
    if(l>r) return 0;
    int mid=(l+r)>>1,now=cur[mid];
    tr[tr[now].son[0]=build(l,mid-1)].fa=now;
    tr[tr[now].son[1]=build(mid+1,r)].fa=now;
    tr[now].siz=tr[tr[now].son[0]].siz+tr[tr[now].son[1]].siz+1;
    return now;
}

// Rebuild the subtree rooted at x into a balanced one and reattach it to x's parent.
inline void rebuild(int x){
    sum=0; recycle(x);
    int fa=tr[x].fa,whicx=tr[tr[x].fa].son[1]==x;
    int cur=build(1,sum);            // new subtree root (shadows the global cur[] array)
    tr[tr[fa].son[whicx]=cur].fa=fa;
    if(x==rt) rt=cur;
}

// Insert x as a new leaf (values >= a node's value go right), then rebuild
// the highest ancestor that became unbalanced, if any.
inline void insert(int x){
    int now=rt,cur=++cnt;
    tr[cur].siz=1,tr[cur].val=x;
    while(1){
        tr[now].siz++;
        bool whicx=(x>=tr[now].val);
        if(tr[now].son[whicx]) now=tr[now].son[whicx];
        else{ tr[tr[now].son[whicx]=cur].fa=now; break; }
    }
    int flag=0;
    for(int i=cur;i;i=tr[i].fa) if(!balance(i)) flag=i;   // last hit is the highest
    if(flag) rebuild(flag);
}

// Index of a node storing x. Requires x to be present.
inline int get_num(int x){
    int now=rt;
    while(1){
        if(tr[now].val==x) return now;
        else now=tr[now].son[tr[now].val<x];
    }
}

// Physically remove node x. A node with two children first takes the value
// of its in-order predecessor, and the predecessor's node is removed instead.
inline void del(int x){
    if(tr[x].son[0]&&tr[x].son[1]){
        int cur=tr[x].son[0];
        while(tr[cur].son[1]) cur=tr[cur].son[1];
        tr[x].val=tr[cur].val;
        x=cur;
    }
    int whic=(tr[x].son[0])?tr[x].son[0]:tr[x].son[1];   // x's only child (or 0)
    int k=(tr[tr[x].fa].son[1]==x);
    tr[tr[tr[x].fa].son[k]=whic].fa=tr[x].fa;
    for(int i=tr[x].fa;i;i=tr[i].fa) tr[i].siz--;
    if(x==rt) rt=whic;
}

// Number of nodes with value < x. The -2e9 sentinel is always counted, so for
// a stored x this equals its 1-based rank among the real values.
inline int get_rank(int x){
    int now=rt,ans=0;
    while(now){
        if(tr[now].val<x){ ans+=tr[tr[now].son[0]].siz+1; now=tr[now].son[1]; }
        else now=tr[now].son[0];
    }
    return ans;
}

// Node holding the x-th smallest value over all nodes (sentinels included).
inline int get_kth(int x){
    int now=rt;
    while(1){
        if(tr[tr[now].son[0]].siz==x-1) return now;
        else if(tr[tr[now].son[0]].siz>=x) now=tr[now].son[0];
        else{ x-=tr[tr[now].son[0]].siz+1; now=tr[now].son[1]; }
    }
    return now;
}

// Largest stored value strictly less than x (-2e9 if none).
inline int get_pre(int x){
    int now=rt,ans=-2e9;
    while(now){
        if(tr[now].val<x){ ans=max(ans,tr[now].val); now=tr[now].son[1]; }
        else now=tr[now].son[0];
    }
    return ans;
}

// Smallest stored value strictly greater than x (2e9 if none).
inline int get_suc(int x){
    int now=rt,ans=2e9;
    while(now){
        if(tr[now].val>x){ ans=min(ans,tr[now].val); now=tr[now].son[0]; }
        else now=tr[now].son[1];
    }
    return ans;
}

// Sets up the two sentinels (node 1 = -2e9 as root, node 2 = 2e9 as its right
// child), then reads n operations "typ x" and prints query answers.
int main(){
    cnt=2;rt=1;
    tr[1].val=-2e9,tr[1].siz=2,tr[1].son[1]=2;
    tr[2].val=2e9,tr[2].siz=1,tr[2].fa=1;
    n=read();
    int typ,x;
    for(int i=1;i<=n;i++){
        typ=read(),x=read();
        switch(typ){
            case 1:insert(x);break;
            case 2:del(get_num(x));break;
            case 3:printf("%d\n",get_rank(x));break;
            case 4:printf("%d\n",tr[get_kth(x+1)].val);break;   // +1 skips the -2e9 sentinel
            case 5:printf("%d\n",get_pre(x));break;
            case 6:printf("%d\n",get_suc(x));
        }
    }
}