二逼平衡树 题解(树套树)
我 想 扇 死 自 己
void up(int x) { if(x) { size[x]=cnt[x];//我TM这行忘了 if(son[x][0])size[x]+=size[son[x][0]]; if(son[x][1])size[x]+=size[son[x][1]]; } }
4个小时!调一道模板!我敲里码!
上道splay刚因为细节打错浪费了3个小时时间,这次就又**重现了
不多说了,先把splay抄上10遍,手写!
-----------以下是正经题解----------------
第一道树套树:线段树套splay
对于线段树的每一段区间建splay维护这段的信息
在合并时:
排名相加;
前驱取max;
后继取min;
比较麻烦的是查询数值,需要二分答案.
以数值为值域进行二分,不断询问mid的排名来缩小范围。
#include<cstdio> #include<iostream> #include<algorithm> #include<cstring> using namespace std; const int N=4000005,inf=1e9; int n,m,a[N]; int root[N],son[N][3],fa[N],key[N],size[N],type,cnt[N]; void clear(int x) { if(!x)return ; fa[x]=cnt[x]=son[x][0]=son[x][1]=size[x]=key[x]=0; } int pre(int k) { int now=son[root[k]][0]; while(son[now][1])now=son[now][1]; return now; } bool judge(int x) { return son[fa[x]][1]==x; } void up(int x) { if(x) { size[x]=cnt[x]; if(son[x][0])size[x]+=size[son[x][0]]; if(son[x][1])size[x]+=size[son[x][1]]; } } void rotate(int x) { int old=fa[x],oldf=fa[old],lr=judge(x); son[old][lr]=son[x][lr^1]; fa[son[old][lr]]=old; son[x][lr^1]=old; fa[old]=x; fa[x]=oldf; if(oldf)son[oldf][son[oldf][1]==old]=x; up(old);up(x); } void splay(int k,int x) { for(int f;f=fa[x];rotate(x)) if(fa[f])rotate(judge(x)==judge(f)?f:x); root[k]=x; } void ins(int k,int x) { if(!root[k]) { type++; key[type]=x; root[k]=type; cnt[type]=size[type]=1; fa[type]=son[type][0]=son[type][1]=0; return ; } int now=root[k],f=0; while(1) { if(x==key[now]) { cnt[now]++; up(now); up(f); splay(k,now); return ; } f=now;now=son[now][key[now]<x]; if(!now) { type++; size[type]=cnt[type]=1; son[type][0]=son[type][1]=0; son[f][x>key[f]]=type; fa[type]=f; key[type]=x; up(f);splay(k,type); return ; } } } int getrank(int k,int x) { int now=root[k],ans=0; while(1) { if(!now)return ans; if(x==key[now])return (son[now][0]?size[son[now][0]]:0)+ans; else if(x>key[now]) { ans+=(son[now][0]?size[son[now][0]]:0)+cnt[now]; now=son[now][1]; } else if(x<key[now])now=son[now][0]; } } int findpos(int k,int x) { int now=root[k]; while(1) { if(x==key[now])return now; else if(x<key[now])now=son[now][0]; else now=son[now][1]; } } int findpre(int k,int x) { int now=root[k],ans=0; while(now) { if(key[now]<x) { if(ans<key[now])ans=key[now]; now=son[now][1]; } else now=son[now][0]; } return ans; } int findnxt(int k,int x) { int now=root[k],ans=inf; while(now) { if(key[now]>x) { if(ans>key[now])ans=key[now]; now=son[now][0]; } else now=son[now][1]; } return ans; } void del(int k,int x) { int now=findpos(k,x); splay(k,now); if(cnt[root[k]]>1) { cnt[root[k]]--; up(root[k]); return ; } else if(!son[root[k]][0]&&(!son[root[k]][1])) { clear(root[k]); root[k]=0; return ; } int old=root[k]; if(son[root[k]][0]*son[root[k]][1]==0) { if(!son[root[k]][0])root[k]=son[root[k]][1]; else root[k]=son[root[k]][0]; fa[root[k]]=0; clear(old); return ; } int L=pre(k); splay(k,L); son[root[k]][1]=son[old][1]; fa[son[old][1]]=root[k]; clear(old); up(root[k]); } #define ls(k) k<<1 #define rs(k) k<<1|1 void update(int k,int l,int r,int pos,int val) { ins(k,val); if(l==r)return ; int mid=l+r>>1; if(pos<=mid)update(ls(k),l,mid,pos,val); else update(rs(k),mid+1,r,pos,val); return ; } int rank(int k,int l,int r,int L,int R,int val) { if(l>=L&&r<=R) { int res=getrank(k,val); return res; } int mid=l+r>>1,res=0; if(L<=mid)res+=rank(ls(k),l,mid,L,R,val); if(R>mid)res+=rank(rs(k),mid+1,r,L,R,val); return res; } void modify(int k,int l,int r,int pos,int val) { del(k,a[pos]); ins(k,val); if(l==r)return ; int mid=l+r>>1; if(pos<=mid)modify(ls(k),l,mid,pos,val); else modify(rs(k),mid+1,r,pos,val); } int getpre(int k,int l,int r,int L,int R,int val) { if(l>=L&&r<=R)return findpre(k,val); int mid=l+r>>1,res=0; if(L<=mid)res=max(res,getpre(ls(k),l,mid,L,R,val)); if(R>mid)res=max(res,getpre(rs(k),mid+1,r,L,R,val)); return res; } int getnxt(int k,int l,int r,int L,int R,int val) { if(l>=L&&r<=R)return findnxt(k,val); int mid=l+r>>1,res=inf; if(L<=mid)res=min(res,getnxt(ls(k),l,mid,L,R,val)); if(R>mid)res=min(res,getnxt(rs(k),mid+1,r,L,R,val)); return res; } inline int read() { int x=0,f=1;char ch=getchar(); while(ch<'0'||ch>'9') {if(ch=='-')f=-1;ch=getchar();} while(ch>='0'&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();} return x*f; } int main() { n=read();m=read(); int op,maxx=0; for(int i=1;i<=n;i++) { a[i]=read(); update(1,1,n,i,a[i]); maxx=max(maxx,a[i]); } while(m--) { op=read(); if(op==1) { int l=read(),r=read(),val=read(); printf("%d\n",rank(1,1,n,l,r,val)+1); } else if(op==2) { int l=read(),r=read(),val=read(); int L=0,R=maxx+1; while(L!=R) { int mid=L+R>>1; int res=rank(1,1,n,l,r,mid); //cout<<"***"<<res<<endl; if(res<val)L=mid+1; else R=mid; } printf("%d\n",L-1); } else if(op==3) { int pos=read(),val=read();modify(1,1,n,pos,val); a[pos]=val; maxx=max(maxx,val); } else if(op==4) { int l=read(),r=read(),val=read(); printf("%d\n",getpre(1,1,n,l,r,val)); } else if(op==5) { int l=read(),r=read(),val=read(); printf("%d\n",getnxt(1,1,n,l,r,val)); } } return 0; }
好了。
从上面那段简短而狗屁不通的“题解”和几乎是抄来的代码可以看出来,是什么让当时的我那么垃圾。
不求甚解、生搬硬套、懒于思考、依赖题解。
装模作样打个Splay,考场上没板子真的写得出来?
如果像本题一样,把普通平衡树的操作放到区间上,显然是无法只用平衡树维护的。解决区间问题最有力的武器就是线段树,所以考虑线段树套平衡树解决。
对每个线段树区间建一棵平衡树。建树时直接把所有区间都插入该区间的所有元素,单点修改时把沿路的所有线段树区间上的平衡树都进行改动(删除再插入)。
对于剩下的查询操作,求排名显然可以转化为所有区间小于该数的元素个数之和+1,即$( \sum (每个区间求排名结果-1)) +1$,前驱应当是所有区间结果的最大值,同理后继就是最小值。
但用相同的方式求K大是不太可行的,考虑牺牲一下时间复杂度进行二分答案,每次二分出一个数check它的排名即可。这样的话是3个$log$。
平衡树使用的是替罪羊树,一是确实好写且容易封装,二是动态开点删点可以避免内存超限。这样就可以直接粗暴地扔到结构体里而不用像Splay一样使用$root[]$数组了。
上面瞎写的东西我没有删。给自己和大家一个警示以及反面典型。
#include<cstdio> #include<iostream> #include<cstring> #include<vector> using namespace std; int read() { int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();} while(isdigit(ch))x=x*10+ch-'0',ch=getchar(); return x*f; } const int N=1e5+5,inf=2147483647; const double al=0.7; int n,m,a[N]; struct Scapegoat { struct node { node *l,*r; int val,size,cnt; bool del; bool bad() { return l->cnt>al*cnt+5||r->cnt>al*cnt+5; } void up() { size=!del+l->size+r->size; cnt=1+l->cnt+r->cnt; } }; node *null,**badtag; void dfs(node *k,vector<node*> &v) { if(k==null)return ; dfs(k->l,v); if(!k->del)v.push_back(k); dfs(k->r,v); if(k->del)delete k; } node *build(vector<node*> &v,int l,int r) { if(l>=r)return null; int mid=l+r>>1; node *k=v[mid]; k->l=build(v,l,mid); k->r=build(v,mid+1,r); k->up(); return k; } void rebuild(node* &k) { vector<node*> v; dfs(k,v); k=build(v,0,v.size()); } void insert(int x,node* &k) { if(k==null) { k=new node; k->l=k->r=null; k->del=0; k->size=k->cnt=1; k->val=x; return ; } ++k->size;++k->cnt; if(x>=k->val)insert(x,k->r); else insert(x,k->l); if(k->bad())badtag=&k; else if(badtag!=&null) k->cnt-=(*badtag)->cnt-(*badtag)->size; } void ins(int x,node* &k) { badtag=&null; insert(x,k); if(badtag!=&null)rebuild(*badtag); } int getrk(node *now,int x) { int ans=1; while(now!=null) { if(now->val>=x)now=now->l; else { ans+=now->l->size+!now->del; now=now->r; } } return ans; } int kth(node *now,int x) { while(now!=null) { if(!now->del&&now->l->size+1==x) return now->val; if(now->l->size>=x)now=now->l; else { x-=now->l->size+!now->del; now=now->r; } } return -1; } void erase(node *k,int rk) { if(!k->del&&rk==k->l->size+1) { k->del=1; --k->size; return ; } --k->size; if(rk<=k->l->size+!k->del)erase(k->l,rk); else erase(k->r,rk-k->l->size-!k->del); } node* root; Scapegoat() { null=new node; root=null; } }s[N<<3]; #define ls(k) (k)<<1 #define rs(k) (k)<<1|1 void build(int k,int l,int r) { for(int i=l;i<=r;i++) s[k].ins(a[i],s[k].root); if(l==r)return ; int mid=l+r>>1; build(ls(k),l,mid); build(rs(k),mid+1,r); } int askrk(int k,int l,int r,int L,int R,int val) { if(L<=l&&R>=r)return s[k].getrk(s[k].root,val)-1; int mid=l+r>>1,res=0; if(L<=mid)res+=askrk(ls(k),l,mid,L,R,val); if(R>mid)res+=askrk(rs(k),mid+1,r,L,R,val); return res; } void update(int k,int l,int r,int pos,int val) { s[k].erase(s[k].root,s[k].getrk(s[k].root,a[pos])); s[k].ins(val,s[k].root); if(l==r)return ; int mid=l+r>>1; if(pos<=mid)update(ls(k),l,mid,pos,val); else update(rs(k),mid+1,r,pos,val); } int askpre(int k,int l,int r,int L,int R,int val) { if(L<=l&&R>=r)return s[k].kth(s[k].root,s[k].getrk(s[k].root,val)-1); int res=-inf,mid=l+r>>1; if(L<=mid) { int ret=askpre(ls(k),l,mid,L,R,val); if(ret==-1)res=max(res,-inf); else res=max(res,ret); } if(R>mid) { int ret=askpre(rs(k),mid+1,r,L,R,val); if(ret==-1)res=max(res,-inf); else res=max(res,ret); } return res; } int asknxt(int k,int l,int r,int L,int R,int val) { if(L<=l&&R>=r)return s[k].kth(s[k].root,s[k].getrk(s[k].root,val+1)); int res=inf,mid=l+r>>1; if(L<=mid) { int ret=asknxt(ls(k),l,mid,L,R,val); if(ret==-1)res=min(res,inf); else res=min(res,ret); } if(R>mid) { int ret=asknxt(rs(k),mid+1,r,L,R,val); if(ret==-1)res=min(res,inf); else res=min(res,ret); } return res; } int askth(int L,int R,int val) { int l=0,r=1e8,res; while(l<=r) { int mid=l+r>>1; if(askrk(1,1,n,L,R,mid)+1<=val)res=mid,l=mid+1; else r=mid-1; } return res; } int main() { n=read();m=read(); for(int i=1;i<=n;i++) a[i]=read(); build(1,1,n); while(m--) { int op=read(); if(op==1){int l=read(),r=read(),K=read();printf("%d\n",askrk(1,1,n,l,r,K)+1);} if(op==2){int l=read(),r=read(),K=read();printf("%d\n",askth(l,r,K));} if(op==3){int pos=read(),K=read();update(1,1,n,pos,K);a[pos]=K;} if(op==4){int l=read(),r=read(),K=read();printf("%d\n",askpre(1,1,n,l,r,K));} if(op==5){int l=read(),r=read(),K=read();printf("%d\n",asknxt(1,1,n,l,r,K));} } return 0; }
兴许青竹早凋,碧梧已僵,人事本难防。