【BZOJ 3196】二逼平衡树 线段树套splay 模板题
我写的是线段树套splay,网上很多人写的都是套treap,然而本蒟蒻并不会treap
奉上sth神犇的模板:
// BZOJ 3196 (two-level balanced tree): a segment tree over positions in which
// every segment-tree node owns a splay tree holding the values of its range.
// Supports: change one position's value; rank of a value in a range; k-th
// smallest in a range; predecessor / successor of a value in a range.
// The k-th smallest query is O(log^3 n) (binary search over the value domain
// on top of a rank query); every other query is O(log^2 n).
#include <cstdio>
using namespace std;

const int maxn = 2000000, inf = 999999999;

// Splay-node pool shared by all segment-tree nodes (nodes are never reused).
// Node 0 is the "nil" node. Nodes 1 and 2 are never inserted into any tree;
// they only serve as sentinel answers of ppre / ssucc (a[1] = -1, a[2] = inf).
int a[maxn], f[maxn], sum[maxn], num[maxn], son[maxn][2];
int n, m, tot;
int c[50001];      // current sequence values, 1-based
int root[200000];  // root splay node of each segment-tree node (0 = empty)
// Current query range [ql, qr]. Named ql/qr rather than left/right so they
// cannot collide with std::left / std::right under `using namespace std`.
int ql, qr;

int mmin(int x, int y) { return x < y ? x : y; }
int mmax(int x, int y) { return x > y ? x : y; }

// Rotates x above its parent y. w is the side of x on which y ends up
// (w = 1: x was a left child; w = 0: x was a right child). Only sum[y] is
// refreshed here; splay() refreshes sum[x] once it has finished rotating.
void rotate(int x, int w) {
    int y = f[x];
    if (f[y]) {
        if (y == son[f[y]][0]) son[f[y]][0] = x;
        else son[f[y]][1] = x;
    }
    f[x] = f[y];
    if (son[x][w]) f[son[x][w]] = y;
    son[y][1 - w] = son[x][w];
    f[y] = x;
    son[x][w] = y;
    sum[y] = sum[son[y][0]] + sum[son[y][1]] + num[y];
}

// Splays x (in the tree of segment node r) until its parent is w.
// w == 0 means "to the root", in which case root[r] is updated.
void splay(int r, int x, int w) {
    while (f[x] != w) {
        int y = f[x];
        if (f[y] == w) {                       // zig
            if (x == son[y][0]) rotate(x, 1);
            else rotate(x, 0);
        } else if (y == son[f[y]][0]) {        // y is a left child
            if (x == son[y][0]) { rotate(y, 1); rotate(x, 1); }   // zig-zig
            else { rotate(x, 0); rotate(x, 1); }                  // zig-zag
        } else {                               // y is a right child
            if (x == son[y][0]) { rotate(x, 1); rotate(x, 0); }   // zig-zag
            else { rotate(y, 0); rotate(x, 0); }                  // zig-zig
        }
    }
    sum[x] = sum[son[x][0]] + sum[son[x][1]] + num[x];
    if (w == 0) root[r] = x;
}

// Walks down from x toward `value` and returns the node holding it, or the
// last node on the search path when the value is absent (0 for an empty tree).
int descend(int x, int value) {
    while (a[x] != value) {
        int next = son[x][a[x] > value ? 0 : 1];
        if (!next) break;
        x = next;
    }
    return x;
}

// Inserts one copy of value into the tree of segment node r.
// Equal values share one node whose multiplicity is kept in num[].
void insert(int r, int value) {
    if (root[r] == 0) {
        a[++tot] = value;
        num[tot] = sum[tot] = 1;
        root[r] = tot;
        return;
    }
    int x = descend(root[r], value);
    if (a[x] == value) { num[x]++; splay(r, x, 0); return; }
    a[++tot] = value;
    f[tot] = x;
    num[tot] = sum[tot] = 1;
    if (value < a[x]) son[x][0] = tot;
    else son[x][1] = tot;
    splay(r, tot, 0);
}

// Removes one copy of value from the tree of segment node r.
// Does nothing if the value is not present.
void del(int r, int value) {
    int x = descend(root[r], value);
    if (x == 0 || a[x] != value) return;   // absent: never remove a neighbour by mistake
    splay(r, x, 0);
    if (num[x] > 1) { num[x]--; sum[x]--; return; }   // keep the root's size in sync
    if (!son[x][0]) { root[r] = son[x][1]; f[son[x][1]] = 0; return; }
    if (!son[x][1]) { root[r] = son[x][0]; f[son[x][0]] = 0; return; }
    // Two children: lift the predecessor y right under x, then let y adopt
    // x's right subtree and become the new root.
    int y = son[x][0];
    while (son[y][1]) y = son[y][1];
    splay(r, y, x);
    son[y][1] = son[x][1];
    f[y] = 0;   // y is the new root: its parent link must be cleared
    f[son[y][1]] = y;
    sum[y] = sum[son[y][0]] + sum[son[y][1]] + num[y];
    root[r] = y;
}

// Builds one splay tree per segment-tree node covering c[l..r].
void build(int now, int l, int r) {
    for (int i = l; i <= r; i++) insert(now, c[i]);
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(now << 1, l, mid);
    build(now << 1 | 1, mid + 1, r);
}

// Splays the node holding value (or the last node on its search path) to the
// root of tree r and returns it.
int find(int r, int value) {
    int x = descend(root[r], value);
    splay(r, x, 0);
    return x;
}

// Number of elements strictly smaller than k in the tree of segment node r.
// (Named countLess instead of rank to avoid clashing with std::rank.)
int countLess(int r, int k) {
    int x = find(r, k);
    if (a[x] >= k) return sum[son[x][0]];
    return sum[son[x][0]] + num[x];
}

// Number of elements < k at positions [ql, qr].
int getrank(int now, int l, int r, int k) {
    if (l >= ql && r <= qr) return countLess(now, k);
    int mid = (l + r) >> 1, w = now << 1, ret = 0;
    if (ql <= mid) ret = getrank(w, l, mid, k);
    if (qr > mid) ret += getrank(w + 1, mid + 1, r, k);
    return ret;
}

// Replaces c[pos] by value in every splay tree on the root-to-leaf path.
// The caller updates c[pos] afterwards.
void change(int now, int l, int r, int pos, int value) {
    del(now, c[pos]);
    insert(now, value);
    if (l == r) return;
    int mid = (l + r) >> 1, w = now << 1;
    if (pos <= mid) change(w, l, mid, pos, value);
    else change(w + 1, mid + 1, r, pos, value);
}

// Node holding the largest element < value in tree r, or sentinel node 1.
int ppre(int r, int value) {
    int x = descend(root[r], value);
    if (a[x] == value) {
        splay(r, x, 0);
        if (!son[x][0]) return 1;
        x = son[x][0];
        while (son[x][1]) x = son[x][1];
        return x;
    }
    // Value absent: the answer is the deepest node on the search path that is
    // smaller than value, found by climbing back up from where the walk stopped.
    while (f[x] && a[x] > value) x = f[x];
    if (a[x] < value) return x;
    return 1;
}

// Node holding the smallest element > value in tree r, or sentinel node 2.
int ssucc(int r, int value) {
    int x = descend(root[r], value);
    if (a[x] == value) {
        splay(r, x, 0);
        if (!son[x][1]) return 2;
        x = son[x][1];
        while (son[x][0]) x = son[x][0];
        return x;
    }
    while (f[x] && a[x] < value) x = f[x];
    if (a[x] > value) return x;
    return 2;
}

// Largest element < value at positions [ql, qr]; -1 if there is none.
int pre(int now, int l, int r, int value) {
    if (l >= ql && r <= qr) {
        int x = ppre(now, value);
        if (a[x] < value) return a[x];
        return -1;
    }
    int mid = (l + r) >> 1, w = now << 1, ret = -1;
    if (ql <= mid) ret = mmax(ret, pre(w, l, mid, value));
    if (qr > mid) ret = mmax(ret, pre(w + 1, mid + 1, r, value));
    return ret;
}

// Smallest element > value at positions [ql, qr]; inf if there is none.
int succ(int now, int l, int r, int value) {
    if (l >= ql && r <= qr) {
        int x = ssucc(now, value);
        if (a[x] > value) return a[x];
        return inf;
    }
    int mid = (l + r) >> 1, w = now << 1, ret = inf;
    if (ql <= mid) ret = mmin(ret, succ(w, l, mid, value));
    if (qr > mid) ret = mmin(ret, succ(w + 1, mid + 1, r, value));
    return ret;
}

int main() {
    if (scanf("%d%d", &n, &m) != 2) return 0;
    for (int i = 1; i <= n; i++) scanf("%d", &c[i]);
    tot = 2;
    a[1] = -1;    // sentinel answer of ppre
    a[2] = inf;   // sentinel answer of ssucc
    build(1, 1, n);
    for (int i = 1; i <= m; i++) {
        int kind, x, y, z = 0;
        if (scanf("%d%d%d", &kind, &x, &y) != 3) break;
        if (kind != 3) scanf("%d", &z);
        if (kind == 1) {
            ql = x; qr = y;
            printf("%d\n", getrank(1, 1, n, z) + 1);
        } else if (kind == 2) {
            // k-th smallest = largest v in [0, 1e8] with (#elements < v) < z.
            ql = x; qr = y;
            int lo = 0, hi = 100000000, ans = 0;
            while (lo <= hi) {
                int mid = (lo + hi) >> 1;
                if (getrank(1, 1, n, mid) < z) { ans = mid; lo = mid + 1; }
                else hi = mid - 1;
            }
            printf("%d\n", ans);
        } else if (kind == 3) {
            change(1, 1, n, x, y);
            c[x] = y;
        } else if (kind == 4) {
            ql = x; qr = y;
            printf("%d\n", pre(1, 1, n, z));
        } else {
            ql = x; qr = y;
            printf("%d\n", succ(1, 1, n, z));
        }
    }
    return 0;
}
然后是我的AC code:
/*调了好久,一部分原因是做这道题时学长在讲新课,没时间调这道题。今天终于调出来了,于是把三篇平衡树的博客一起写了出来。这道题的题解在上面sth神犇的模板里有,本题的特色主要在于查询区间第k小时要对答案的值域进行二分(每次二分用区间排名查询来判断)。我这道题一直WA的原因是快速读入没有处理负号(╯‵□′)╯︵┻━┻*/
// BZOJ 3196: a segment tree over positions where every segment-tree node owns
// a pointer-based splay tree containing the values of its range.
#include<cstdio>
#include<cctype>
#include<cstring>
#include<algorithm>
using namespace std;

const int MAXN = 50000;          // largest supported sequence length
const int NO_PRED = -1;          // answer of predd / que4 when nothing is smaller
const int NO_SUCC = 100000010;   // answer of pross / que5 when nothing is larger (1E8+10)

struct node {
    node();
    node *fa, *ch[2];
    int sum, d;                  // subtree size, stored value
    // 1 if this node is its parent's right child, 0 otherwise.
    short pl() { return this == fa->ch[1]; }
    // Recomputes the subtree size from the children (the sentinel has sum 0).
    void count() { sum = ch[0]->sum + ch[1]->sum + 1; }
} *null;                         // shared sentinel standing in for "no node"
node::node() { fa = ch[1] = ch[0] = null; sum = 0; d = 0; }

// Current sequence, 1-based. Named seq rather than `data`, which is ambiguous
// with std::data (C++17) under `using namespace std`.
int seq[MAXN + 3], N, M;
node *ROOT[MAXN << 2];           // splay root of each segment-tree node

// Reads one (possibly negative) integer from stdin. A '-' immediately before
// the digits makes it negative. Returns 0 if EOF is hit before any digit.
int getint() {
    int c = getchar();           // int, not char: keeps EOF distinct and isdigit() defined
    int sign = 1;
    while (c != EOF && !isdigit(c)) {
        sign = (c == '-') ? -1 : 1;
        c = getchar();
    }
    int a = 0;
    while (c != EOF && isdigit(c)) {
        a = a * 10 + (c - '0');
        c = getchar();
    }
    return a * sign;
}

namespace Splay {
    // Creates the sentinel. The first construction runs while `null` is still
    // nullptr, so it is re-initialised to make the sentinel's links point to itself.
    void Builda() {
        null = new node;
        *null = node();
    }

    // Allocates a single node holding x whose parent is fa.
    node *newNode(int x, node *fa) {
        node *p = new node;
        p->fa = fa;
        p->d = x;
        p->count();
        return p;
    }

    // Rotates k above its parent, keeping ROOT[rt] up to date.
    void rotate(node *k, int rt) {
        node *r = k->fa;
        if (k == null || r == null) return;
        int x = k->pl() ^ 1;          // side of k on which r will hang
        r->ch[x ^ 1] = k->ch[x];      // k's inner subtree moves under r
        r->ch[x ^ 1]->fa = r;         // may write null->fa; only null's ch[] and sum are relied upon
        if (r->fa == null) ROOT[rt] = k;
        else r->fa->ch[r->pl()] = k;
        k->fa = r->fa;
        r->fa = k;
        k->ch[x] = r;
        r->count();
        k->count();
    }

    // Splays r until its parent is tar (the sentinel means "to the root").
    void splay(int rt, node *r, node *tar = null) {
        for (; r->fa != tar; rotate(r, rt))
            if (r->fa->fa != tar)
                rotate(r->pl() == r->fa->pl() ? r->fa : r, rt);
    }

    // Inserts x into tree rt; equal values go to the right subtree.
    void insert(int rt, int x) {
        if (ROOT[rt] == null) { ROOT[rt] = newNode(x, null); return; }
        node *r = ROOT[rt];
        while (true) {
            int c = (x < r->d) ? 0 : 1;
            if (r->ch[c] == null) {
                r->ch[c] = newNode(x, r);
                splay(rt, r->ch[c]);
                return;
            }
            r = r->ch[c];
        }
    }

    // Rebuilds tree rt from seq[l..r].
    void build(int rt, int l, int r) {
        ROOT[rt] = null;
        for (int i = l; i <= r; ++i) insert(rt, seq[i]);
    }

    // Some node holding value x in tree rt, or the sentinel if x is absent.
    node *findNode(int rt, int x) {
        node *r = ROOT[rt];
        while (r != null && r->d != x) r = r->ch[x < r->d ? 0 : 1];
        return r;
    }

    // Rightmost (maximum) node of the subtree rooted at r.
    node *rightdown(node *r) {
        while (r->ch[1] != null) r = r->ch[1];
        return r;
    }

    // Removes one copy of x from tree rt. Does nothing if x is absent.
    void deletee(int rt, int x) {
        node *r = findNode(rt, x);
        if (r == null) return;        // never splay / free the shared sentinel
        splay(rt, r);
        if (r->ch[0] == null && r->ch[1] == null) {
            ROOT[rt] = null;
        } else if (r->ch[0] == null) {
            r->ch[1]->fa = null;
            ROOT[rt] = r->ch[1];
        } else if (r->ch[1] == null) {
            r->ch[0]->fa = null;
            ROOT[rt] = r->ch[0];
        } else {
            // Lift the predecessor directly under r; it has no right child,
            // so it can adopt r's right subtree and become the new root.
            splay(rt, rightdown(r->ch[0]), ROOT[rt]);
            node *p = r->ch[0];
            p->ch[1] = r->ch[1];
            r->ch[1]->fa = p;
            p->fa = null;
            p->count();
            ROOT[rt] = p;
        }
        delete r;
    }

    // Largest value < x in the subtree r, or NO_PRED (iterative: splay trees
    // can be O(n) deep, which would make a recursive walk risky).
    int predd(node *r, int x) {
        int best = NO_PRED;
        while (r != null) {
            if (x <= r->d) r = r->ch[0];
            else { best = max(best, r->d); r = r->ch[1]; }
        }
        return best;
    }

    // Smallest value > x in the subtree r, or NO_SUCC.
    int pross(node *r, int x) {
        int best = NO_SUCC;
        while (r != null) {
            if (x >= r->d) r = r->ch[1];
            else { best = min(best, r->d); r = r->ch[0]; }
        }
        return best;
    }

    // Number of values < x in tree rt. Counts along the search path, then
    // splays the last visited node so the amortised bound still holds
    // (no temporary insert/delete needed when x is absent).
    int quee1(int rt, int x) {
        node *r = ROOT[rt], *last = null;
        int cnt = 0;
        while (r != null) {
            last = r;
            if (x <= r->d) r = r->ch[0];
            else { cnt += r->ch[0]->sum + 1; r = r->ch[1]; }
        }
        if (last != null) splay(rt, last);
        return cnt;
    }
}

// Builds the segment tree over [l, r]; node rt gets a splay tree of seq[l..r].
// Each node is built exactly once (previously leaves were built twice, leaking
// the first tree).
void buildtree(int l, int r, int rt) {
    Splay::build(rt, l, r);
    if (l == r) return;
    int mid = (l + r) >> 1;
    buildtree(l, mid, rt << 1);
    buildtree(mid + 1, r, rt << 1 | 1);
}

// Number of values < k at positions [L, R].
int que1(int L, int R, int k, int l, int r, int rt) {
    if (L <= l && r <= R) return Splay::quee1(rt, k);
    int mid = (l + r) >> 1, s = 0;
    if (L <= mid) s += que1(L, R, k, l, mid, rt << 1);
    if (R > mid) s += que1(L, R, k, mid + 1, r, rt << 1 | 1);
    return s;
}

// Replaces seq[pos] by k in every splay tree on the path to leaf pos.
// The caller updates seq[pos] afterwards.
void que3(int pos, int k, int l, int r, int rt) {
    if (l <= pos && pos <= r) {
        Splay::deletee(rt, seq[pos]);
        Splay::insert(rt, k);
    }
    if (l == r) return;
    int mid = (l + r) >> 1;
    if (pos <= mid) que3(pos, k, l, mid, rt << 1);
    else que3(pos, k, mid + 1, r, rt << 1 | 1);
}

// Largest value < k at positions [L, R], or NO_PRED.
int que4(int L, int R, int k, int l, int r, int rt) {
    if (L <= l && r <= R) return Splay::predd(ROOT[rt], k);
    int mid = (l + r) >> 1, s = NO_PRED;
    if (L <= mid) s = que4(L, R, k, l, mid, rt << 1);
    if (R > mid) s = max(s, que4(L, R, k, mid + 1, r, rt << 1 | 1));
    return s;
}

// Smallest value > k at positions [L, R], or NO_SUCC.
int que5(int L, int R, int k, int l, int r, int rt) {
    if (L <= l && r <= R) return Splay::pross(ROOT[rt], k);
    int mid = (l + r) >> 1, s = NO_SUCC;
    if (L <= mid) s = que5(L, R, k, l, mid, rt << 1);
    if (R > mid) s = min(s, que5(L, R, k, mid + 1, r, rt << 1 | 1));
    return s;
}

int main() {
    Splay::Builda();
    N = getint();
    M = getint();
    for (int i = 1; i <= N; ++i) seq[i] = getint();
    buildtree(1, N, 1);
    while (M--) {
        int op = getint();
        int l, r, k, pos;
        switch (op) {
        case 1:
            l = getint(); r = getint(); k = getint();
            printf("%d\n", que1(l, r, k, 1, N, 1) + 1);
            break;
        case 2: {
            l = getint(); r = getint(); k = getint();
            // k-th smallest = largest v in [0, 1e8] with (#values < v) + 1 <= k.
            int lo = 0, hi = 100000000, ans = 0;
            while (lo <= hi) {
                int mid = (lo + hi) >> 1;
                if (que1(l, r, mid, 1, N, 1) + 1 <= k) { ans = mid; lo = mid + 1; }
                else hi = mid - 1;
            }
            printf("%d\n", ans);
            break;
        }
        case 3:
            pos = getint(); k = getint();
            que3(pos, k, 1, N, 1);
            seq[pos] = k;
            break;
        case 4:
            l = getint(); r = getint(); k = getint();
            printf("%d\n", que4(l, r, k, 1, N, 1));
            break;
        case 5:
            l = getint(); r = getint(); k = getint();
            printf("%d\n", que5(l, r, k, 1, N, 1));
            break;
        }
    }
    return 0;
}
这样就可以了
NOI 2017 Bless All