BZOJ3196 - 二逼平衡树

Portal

Description

给出一个\(n(n\leq5\times10^4)\)个数的序列,进行\(m(m\leq5\times10^4)\)次操作,操作如下:

  • 查询\(x\)在区间\([L,R]\)内的排名;
  • 查询区间\([L,R]\)内排名为\(x\)的值;
  • 修改第\(pos\)位上的数为\(x\)
  • 查询\(x\)在区间\([L,R]\)内的前驱(前驱定义为小于\(x\),且最大的数);
  • 查询\(x\)在区间\([L,R]\)内的后继(后继定义为大于\(x\),且最小的数)。

Solution

树套树模板题。
顾名思义,树套树就是将两个树形数据结构套在一起,其中外层树的每个节点都存在着一棵里层树。
本题解法多样,在此用线段树套splay解决。每个线段树中的节点都保存着一棵splay的根节点。当我们询问区间\([L,R]\)时,用线段树将其分成\(O(logn)\)个子区间,然后将每个区间的splay的信息合并起来。实际上外层的线段树根本不用建,只是用到了线段树的区间划分方法而已。

  • 查询\(x\)的排名:询问每个子区间中比\(x\)小的有多少个,求和并+1。
  • 查询排名为\(x\)的数:二分答案,查询\(mid\)的排名。时间复杂度\(O(logn\cdot logw)\)
  • 修改:对于包含\(pos\)的区间,删除原数加入\(x\)
  • 查询前驱:询问每个子区间中的前驱,取\(max\)
  • 查询后继:询问每个子区间中的后继,取\(min\)

总时间复杂度\(O(nlog^2nlogw)\)

Code

//二逼平衡树
#include <cstdio>
#include <algorithm>
using namespace std;
inline char gc()
{
    static char now[1<<16],*s,*t;
    if(s==t) {t=(s=now)+fread(now,1,1<<16,stdin); if(s==t) return EOF;}
    return *s++;
}
inline int read()
{
    int x=0,f=1; char ch=gc();
    while(ch<'0'||'9'<ch) f=(ch=='-')?-1:f,ch=gc();
    while('0'<=ch&&ch<='9') x=x*10+ch-'0',ch=gc();
    return x*f;
}
int const N1=2e6+10;
int const N2=2e5+10;
int n,m,a[N2];
int rt[N2];
namespace bTree
{
    int cnt,fa[N1],ch[N1][2],siz[N1]; int val[N1];
    int wh(int p) {return p==ch[fa[p]][1];}
    void create(int p,int x) {fa[p]=ch[p][0]=ch[p][1]=0,siz[p]=1; val[p]=x;}
    void update(int p) {siz[p]=siz[ch[p][0]]+1+siz[ch[p][1]];}
    void rotate(int p)
    {
        int q=fa[p],r=fa[q],w=wh(p);
        fa[p]=r; if(r) ch[r][wh(q)]=p;
        fa[ch[q][w]=ch[p][w^1]]=q;
        fa[ch[p][w^1]=q]=p;
        update(q),update(p);
    }
    void splay(int t,int p)
    {
        for(int q=fa[p];fa[p];rotate(p),q=fa[p]) if(fa[q]) rotate(wh(p)^wh(q)?p:q);
        rt[t]=p;
    }
    void ins(int t,int x)
    {
        int p=rt[t],q=0;
        if(!p) {create(rt[t]=++cnt,x); return;}
        while(p) q=p,p=ch[q][val[q]<x];
        create(p=++cnt,x);
        fa[ch[q][val[q]<x]=p]=q;
        splay(t,p);
    }
    void del(int t,int x)
    {
        int p=rt[t];
        while(p&&val[p]!=x) p=ch[p][val[p]<x];
        if(!p) return;
        splay(t,p); int chCnt=(ch[p][0]>0)+(ch[p][1]>0);
        if(chCnt==0) {rt[t]=0; return;}
        if(chCnt==1) {fa[rt[t]=ch[p][ch[p][1]>0]]=0; return;}
        int q=ch[p][0]; while(ch[q][1]) q=ch[q][1];
        splay(t,q); fa[ch[q][1]=ch[p][1]]=q; update(q);
    }
    int find(int t,int x)
    {
        int res=0;
        for(int p=rt[t];p;p=ch[p][val[p]<x])
            if(x>val[p]) res+=siz[ch[p][0]]+1;
        return res;
    }
    int pre(int t,int x)
    {
        int res=-1;
        for(int p=rt[t];p;p=ch[p][val[p]<x]) if(val[p]<x) res=val[p];
        return res;
    }
    int nxt(int t,int x)
    {
        int res=1e8;
        for(int p=rt[t];p;p=ch[p][val[p]<=x]) if(val[p]>x) res=val[p];
        return res;
    }
}
#define Ls (p<<1)
#define Rs ((p<<1)|1)
int sRt; int L,R;
int sFind(int p,int L0,int R0,int x)
{
    if(L<=L0&&R0<=R) return bTree::find(p,x);
    int mid=L0+R0>>1,res=0;
    if(L<=mid) res+=sFind(Ls,L0,mid,x);
    if(mid<R) res+=sFind(Rs,mid+1,R0,x);
    return res;
}
int getRnk(int x)
{
    int L1=1,R1=1e8;
    while(L1<=R1)
    {
        int mid=L1+R1>>1;
        if(sFind(sRt,1,n,mid)+1<=x) L1=mid+1;
        else R1=mid-1;
    }
    return L1-1;
}
void sChange(int p,int L0,int R0,int x0,int x)
{
    bTree::del(p,x0),bTree::ins(p,x);
    if(L0==R0) return;
    int mid=L0+R0>>1;
    if(L<=mid) sChange(Ls,L0,mid,x0,x);
    else sChange(Rs,mid+1,R0,x0,x);
}
int sPre(int p,int L0,int R0,int x)
{
    if(L<=L0&&R0<=R) return bTree::pre(p,x);
    int mid=L0+R0>>1,res=-1;
    if(L<=mid) res=max(res,sPre(Ls,L0,mid,x));
    if(mid<R) res=max(res,sPre(Rs,mid+1,R0,x));
    return res;
}
int sNxt(int p,int L0,int R0,int x)
{
    if(L<=L0&&R0<=R) return bTree::nxt(p,x);
    int mid=L0+R0>>1,res=1e8;
    if(L<=mid) res=min(res,sNxt(Ls,L0,mid,x));
    if(mid<R) res=min(res,sNxt(Rs,mid+1,R0,x));
    return res;
}
int main()
{
    n=read(),m=read();
    sRt=1;
    for(int i=1;i<=n;i++) L=i,sChange(sRt,1,n,0,a[i]=read());
    for(int i=1;i<=m;i++)
    {
        int opt=read(),k; L=read(),R=read();
        if(opt==1) printf("%d\n",sFind(sRt,1,n,read())+1);
        else if(opt==2) printf("%d\n",getRnk(read()));
        else if(opt==3) sChange(sRt,1,n,a[L],R),a[L]=R;
        else if(opt==4) printf("%d\n",sPre(sRt,1,n,read()));
        else if(opt==5) printf("%d\n",sNxt(sRt,1,n,read()));
    }
    return 0;
}

P.S.

要算好数组大小!如果不循环用点的话,初始时会有\(nlogn\)个点,每次修改会新建\(logn\)个点。

posted @ 2018-03-30 21:14  VisJiao  阅读(186)  评论(0编辑  收藏  举报