树套树浅谈

今天来说一下线段树套Splay。顺便我也来重新敲一遍模板。

首先,明确一下Splay套线段树用来处理什么问题。它可以支持:插入x,删除x,单点修改,查询x在区间[l,r]的排名,查询区间[l,r]中排名为k的数,以及一个数在区间[l,r]中的前驱,后继。(应该还可以查询区间和等东西,还没写过)

其实它的常数非常大,但是这是树套树中最容易理解的一种。

首先,我们知道,对于一个区间,我们给它建一棵线段树,它的每个节点维护的是区间[l,r]的信息。所以,我们对于每一个节点都建一棵Splay.(听起来就很暴力啊……)

不知道为啥我的指针版Splay套炸了,这里用普通版吧。

首先,我们先来完成最基本的Splay操作。

#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#define inf (2147483647)
using namespace std;
const int MAXN=4e6+2;
inline void IN(int &x){
    int s=0,w=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-')w=-1;
        ch=getchar();
    }while(ch>=48&&ch<='9'){
        s=(s<<1)+(s<<3)+(ch^48);
        ch=getchar();
    }x=s*w;
}
int n,m,a[MAXN],ans,MX;
/*-------------------------Splay-------------------------------------*/
int f[MAXN],c[MAXN],s[MAXN],v[MAXN],ch[MAXN][2],rt[MAXN],tot;
inline void Splay_del(int x){f[x]=s[x]=c[x]=v[x]=ch[x][1]=ch[x][0]=0;}
inline void Splay_pushup(int x){s[x]=(ch[x][0]?s[ch[x][0]]:0)+(ch[x][1]?s[ch[x][1]]:0)+c[x];}
inline void Splay_rotate(int x){
    int y=f[x],z=f[y],k=ch[y][1]==x,v=ch[x][k^1];
    ch[y][k]=v;if(v)f[v]=y;f[x]=z;if(z)ch[z][ch[z][1]==y]=x;
    f[y]=x;ch[x][k^1]=y;Splay_pushup(y),Splay_pushup(x);
}

pushup、rotate和del我就不解释了,根据代码理解吧。

然后我们考虑Splay操作的写法。

实际上并没有什么区别,因为是多颗Splay,所以我们要多传一个当前Splay所在区间的编号,对应地,rt[i]就是原来Splay的rt。

代码:

inline void Splay(int i,int x,int top=0){
    while(f[x]!=top){
        int y=f[x],z=f[y];
        if(z!=top)(ch[z][1]==y)^(ch[y][1]==x)?Splay_rotate(y):Splay_rotate(x);
        Splay_rotate(x);
    }if(!top)rt[i]=x;
}

接下来,来讲插入。

对于插入,还是一样的流程,往下找,如果是空树,新建节点,赋值;找到值一样的,c[x]++,pushup;没有的话,新建即可。

注意这里也要传区间编号。

inline void Splay_Insert(int i,int x){
    int pos=rt[i];
    if(!rt[i]){
        rt[i]=pos=++tot;v[tot]=x;s[pos]=c[pos]=1;
        f[pos]=ch[pos][0]=ch[pos][1]=0;return;
    }int last=0;
    while(727){
        if(v[pos]==x){++c[pos];Splay_pushup(last);break;}
        last=pos;pos=ch[pos][x>v[pos]];
        if(!pos){
            pos=++tot;v[pos]=x;s[pos]=c[pos]=1;
            ch[last][x>v[last]]=pos;
            f[pos]=last;ch[pos][0]=ch[pos][1]=0;
            Splay_pushup(last);break;
        }
    }Splay(i,pos);return;
}

727是生日不要在意……

然后是求某个值在区间[l,r]里的rank.

这个东西,我们把它放到线段树里面理解:如果当前x在区间[l,r]中的一段区间[L,R]中排名为k,在区间[L1,R1]中排名为t,并且这两个区间可以合并成大区间[l,r]的话,那么,当前值x在大区间[l,r]的排名就是k+t.

所以,一个Splay_rank支持查询当前rank,一个Seg函数查询总区间的合并之后的rank。

inline int Splay_rank(int i,int k){
    int x=rt[i],cal=0;
    while(x){
        if(v[x]==k)return cal+((ch[x][0])?s[ch[x][0]]:0);
        else if(v[x]<k){
            cal+=((ch[x][0])?s[ch[x][0]]:0)+c[x];x=ch[x][1];
        }else x=ch[x][0];
    }return cal;
}
inline void Seg_rank(int x,int l,int r,int L,int R,int Kth){
    if(l==L&&r==R){ans+=Splay_rank(x,Kth);return;}
    if(R<=mid)Seg_rank(lc,l,mid,L,R,Kth);
    else if(mid<L)Seg_rank(rc,mid+1,r,L,R,Kth);
    else Seg_rank(lc,l,mid,L,mid,Kth),Seg_rank(rc,mid+1,r,mid+1,R,Kth);
}

注意,查询排名的时候,有三种情况,一个是全在左区间,一个是全在右区间,还有横跨两个区间的,不要忘记区分。

Splay找前驱后继就不说了吧……

inline int Splay_Get_pre(int i,int x){
    int pos=rt[i];while(pos){
        if(v[pos]<x){if(ans<v[pos])ans=v[pos];pos=ch[pos][1];}
        else pos=ch[pos][0];
    }return ans;
}
inline int Splay_Get_suc(int i,int x){
    int pos=rt[i];while(pos){
        if(v[pos]>x){if(ans>v[pos])ans=v[pos];pos=ch[pos][0];}
        else pos=ch[pos][1];
    }return ans;
}

然后是删除。

先上代码:

inline int Splay_find(int i,int x){
    int pos=rt[i];while(x){
        if(v[pos]==x){Splay(i,pos);return pos;}
        pos=ch[pos][x>v[pos]];
    }return 0;
}
inline int Splay_pre(int i){int x=ch[rt[i]][0];while(ch[x][1])x=ch[x][1];return x;}
inline int Splay_suc(int i){int x=ch[rt[i]][1];while(ch[x][0])x=ch[x][0];return x;}
inline void Splay_Delete(int i,int key){
    int x=Splay_find(i,key);
    if(c[x]>1){--c[x];Splay_pushup(x);return;}
    if(!ch[x][0]&&!ch[x][1]){Splay_del(rt[i]);rt[i]=0;return;}
    if(!ch[x][0]){int y=ch[x][1];rt[i]=y;f[y]=0;return;}
    if(!ch[x][1]){int y=ch[x][0];rt[i]=y;f[y]=0;return;}
    int p=Splay_pre(i);int lastrt=rt[i];
    Splay(i,p,0);ch[rt[i]][1]=ch[lastrt][1];f[ch[lastrt][1]]=rt[i];
    Splay_del(lastrt);Splay_pushup(rt[i]);
}

先找点,find函数,找完之后判断它是否还存在,如果存在,分类讨论:

左右孩子都没有,直接删掉。

只有右孩子,连接x的右孩子成为根的右孩子

只有左孩子,同理。

如果两个都有,则:

找到区间i的严格小于key的最大数,并记录根。

把这个数旋转到根,然后把之前记录的根的孩子和当前根的孩子连起来,认个爹,然后把之前根删掉,再pushup即可。

下面该线段树的查询操作了。

改天再补一发线段树的入门讲解……

首先是线段树的节点插入。

我们对每一个线段树的节点都插入v,注意是插入进它的Splay里面。

#define lc ((x)<<1)
#define rc ((x)<<1|1)
#define mid ((l+r)>>1)
inline void Seg_Insert(int x,int l,int r,int pos,int val){
    Splay_Insert(x,val);if(l==r)return;
    if(pos<=mid)Seg_Insert(lc,l,mid,pos,val);
    else Seg_Insert(rc,mid+1,r,pos,val);
}

Easily!

继续。

来看看单点修改怎样维护。

注意此处,是要先把要改的点删掉在把修改的点加进来。

并且同时更新Splay.

inline void Seg_change(int x,int l,int r,int pos,int val){
    Splay_Delete(x,a[pos]);Splay_Insert(x,val);
    if(l==r){a[pos]=val;return;}
    if(pos<=mid)Seg_change(lc,l,mid,pos,val);
    else Seg_change(rc,mid+1,r,pos,val);
}

然后是查询区间[l,r]中v的前驱,后继。

对于每一个包含的区间中,都求一遍当前Splay中v的前驱后继,每次统计max或min即可,注意处理区间问题,三种情况分类讨论,见上。

inline void Seg_pre(int x,int l,int r,int L,int R,int val){
    if(l==L&&r==R){ans=max(ans,Splay_Get_pre(x,val));return;}
    if(R<=mid)Seg_pre(lc,l,mid,L,R,val);
    else if(mid<L)Seg_pre(rc,mid+1,r,L,R,val);
    else Seg_pre(lc,l,mid,L,mid,val),Seg_pre(rc,mid+1,r,mid+1,R,val);
}
inline void Seg_suc(int x,int l,int r,int L,int R,int val){
    if(l==L&&r==R){ans=min(ans,Splay_Get_suc(x,val));return;}
    if(R<=mid)Seg_suc(lc,l,mid,L,R,val);
    else if(mid<L)Seg_suc(rc,mid+1,r,L,R,val);
    else Seg_suc(lc,l,mid,L,mid,val),Seg_suc(rc,mid+1,r,mid+1,R,val);
}

最后,查询区间排名为k的数。

此处,我们要用二分。

输入时,记录区间最大值。

L=0,R=MX+1。

二分。

每一次用ans记录当前二分值(M=(L+R)>>1)的排名。

如果ans<k,L=M,else R=mid+1.

最后返回的是L-1。

/*----------------ask-------------*/
inline int Get_Kth(int x,int y,int k){
    int L=0,R=MX+1,M;
    while(L<R){
        M=(L+R)>>1;
        ans=0;Seg_rank(1,1,n,x,y,M);
        if(ans<k)L=M+1;else R=M;
    }return L-1;
}

最后上完整代码。

#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#define inf (2147483647)
using namespace std;
const int MAXN=4e6+2;
inline void IN(int &x){
    int s=0,w=1;
    char ch=getchar();
    while(ch<'0'||ch>'9'){
        if(ch=='-')w=-1;
        ch=getchar();
    }while(ch>=48&&ch<='9'){
        s=(s<<1)+(s<<3)+(ch^48);
        ch=getchar();
    }x=s*w;
}
int n,m,a[MAXN],ans,MX;
/*-------------------------Splay-------------------------------------*/
int f[MAXN],c[MAXN],s[MAXN],v[MAXN],ch[MAXN][2],rt[MAXN],tot;
inline void Splay_del(int x){f[x]=s[x]=c[x]=v[x]=ch[x][1]=ch[x][0]=0;}
inline void Splay_pushup(int x){s[x]=(ch[x][0]?s[ch[x][0]]:0)+(ch[x][1]?s[ch[x][1]]:0)+c[x];}
inline void Splay_rotate(int x){
    int y=f[x],z=f[y],k=ch[y][1]==x,v=ch[x][k^1];
    ch[y][k]=v;if(v)f[v]=y;f[x]=z;if(z)ch[z][ch[z][1]==y]=x;
    f[y]=x;ch[x][k^1]=y;Splay_pushup(y),Splay_pushup(x);
}
inline void Splay(int i,int x,int top=0){
    while(f[x]!=top){
        int y=f[x],z=f[y];
        if(z!=top)(ch[z][1]==y)^(ch[y][1]==x)?Splay_rotate(y):Splay_rotate(x);
        Splay_rotate(x);
    }if(!top)rt[i]=x;
}
inline void Splay_Insert(int i,int x){
    int pos=rt[i];
    if(!rt[i]){
        rt[i]=pos=++tot;v[tot]=x;s[pos]=c[pos]=1;
        f[pos]=ch[pos][0]=ch[pos][1]=0;return;
    }int last=0;
    while(727){
        if(v[pos]==x){++c[pos];Splay_pushup(last);break;}
        last=pos;pos=ch[pos][x>v[pos]];
        if(!pos){
            pos=++tot;v[pos]=x;s[pos]=c[pos]=1;
            ch[last][x>v[last]]=pos;
            f[pos]=last;ch[pos][0]=ch[pos][1]=0;
            Splay_pushup(last);break;
        }
    }Splay(i,pos);return;
}
inline int Splay_rank(int i,int k){
    int x=rt[i],cal=0;
    while(x){
        if(v[x]==k)return cal+((ch[x][0])?s[ch[x][0]]:0);
        else if(v[x]<k){
            cal+=((ch[x][0])?s[ch[x][0]]:0)+c[x];x=ch[x][1];
        }else x=ch[x][0];
    }return cal;
}
inline int Splay_find(int i,int x){
    int pos=rt[i];while(x){
        if(v[pos]==x){Splay(i,pos);return pos;}
        pos=ch[pos][x>v[pos]];
    }return 0;
}
inline int Splay_pre(int i){int x=ch[rt[i]][0];while(ch[x][1])x=ch[x][1];return x;}
inline int Splay_suc(int i){int x=ch[rt[i]][1];while(ch[x][0])x=ch[x][0];return x;}
inline int Splay_Get_pre(int i,int x){
    int pos=rt[i];while(pos){
        if(v[pos]<x){if(ans<v[pos])ans=v[pos];pos=ch[pos][1];}
        else pos=ch[pos][0];
    }return ans;
}
inline int Splay_Get_suc(int i,int x){
    int pos=rt[i];while(pos){
        if(v[pos]>x){if(ans>v[pos])ans=v[pos];pos=ch[pos][0];}
        else pos=ch[pos][1];
    }return ans;
}
inline void Splay_Delete(int i,int key){
    int x=Splay_find(i,key);
    if(c[x]>1){--c[x];Splay_pushup(x);return;}
    if(!ch[x][0]&&!ch[x][1]){Splay_del(rt[i]);rt[i]=0;return;}
    if(!ch[x][0]){int y=ch[x][1];rt[i]=y;f[y]=0;return;}
    if(!ch[x][1]){int y=ch[x][0];rt[i]=y;f[y]=0;return;}
    int p=Splay_pre(i);int lastrt=rt[i];
    Splay(i,p,0);ch[rt[i]][1]=ch[lastrt][1];f[ch[lastrt][1]]=rt[i];
    Splay_del(lastrt);Splay_pushup(rt[i]);
}
/*----------------------Seg_Tree---------------------------*/
#define lc ((x)<<1)
#define rc ((x)<<1|1)
#define mid ((l+r)>>1)
inline void Seg_Insert(int x,int l,int r,int pos,int val){
    Splay_Insert(x,val);if(l==r)return;
    if(pos<=mid)Seg_Insert(lc,l,mid,pos,val);
    else Seg_Insert(rc,mid+1,r,pos,val);
}
inline void Seg_rank(int x,int l,int r,int L,int R,int Kth){
    if(l==L&&r==R){ans+=Splay_rank(x,Kth);return;}
    if(R<=mid)Seg_rank(lc,l,mid,L,R,Kth);
    else if(mid<L)Seg_rank(rc,mid+1,r,L,R,Kth);
    else Seg_rank(lc,l,mid,L,mid,Kth),Seg_rank(rc,mid+1,r,mid+1,R,Kth);
}
inline void Seg_change(int x,int l,int r,int pos,int val){
    Splay_Delete(x,a[pos]);Splay_Insert(x,val);
    if(l==r){a[pos]=val;return;}
    if(pos<=mid)Seg_change(lc,l,mid,pos,val);
    else Seg_change(rc,mid+1,r,pos,val);
}
inline void Seg_pre(int x,int l,int r,int L,int R,int val){
    if(l==L&&r==R){ans=max(ans,Splay_Get_pre(x,val));return;}
    if(R<=mid)Seg_pre(lc,l,mid,L,R,val);
    else if(mid<L)Seg_pre(rc,mid+1,r,L,R,val);
    else Seg_pre(lc,l,mid,L,mid,val),Seg_pre(rc,mid+1,r,mid+1,R,val);
}
inline void Seg_suc(int x,int l,int r,int L,int R,int val){
    if(l==L&&r==R){ans=min(ans,Splay_Get_suc(x,val));return;}
    if(R<=mid)Seg_suc(lc,l,mid,L,R,val);
    else if(mid<L)Seg_suc(rc,mid+1,r,L,R,val);
    else Seg_suc(lc,l,mid,L,mid,val),Seg_suc(rc,mid+1,r,mid+1,R,val);
}
/*----------------ask-------------*/
inline int Get_Kth(int x,int y,int k){
    int L=0,R=MX+1,M;
    while(L<R){
        M=(L+R)>>1;
        ans=0;Seg_rank(1,1,n,x,y,M);
        if(ans<k)L=M+1;else R=M;
    }return L-1;
}
/*-----------------Main--------------------*/
int main(){
    IN(n),IN(m);
    for(register int i=1;i<=n;++i){IN(a[i]);Seg_Insert(1,1,n,i,a[i]);MX=max(MX,a[i]);}
    while(m--){
        int op,x,y,v;IN(op),IN(x),IN(y);
        switch(op){
            case 1:{IN(v);ans=0;Seg_rank(1,1,n,x,y,v);printf("%d\n",ans+1);}break;
            case 2:{IN(v);printf("%d\n",Get_Kth(x,y,v));}break;
            case 3:{Seg_change(1,1,n,x,y);}break;
            case 4:{IN(v);ans=-inf;Seg_pre(1,1,n,x,y,v);printf("%d\n",ans);}break;
            case 5:{IN(v);ans=inf;Seg_suc(1,1,n,x,y,v);printf("%d\n",ans);}break;
        }
    }return 0;
} 

对于习题可以去洛谷的试炼场里找。我这个蒟蒻还是要多练啊……

posted @ 2019-05-26 18:48  Refined_heart  阅读(380)  评论(2编辑  收藏  举报