【模板】平衡树Splay

Splay

支持插入、删除，以及求一个数的排名、前驱、后继和第k小的数

普通平衡树

#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<sstream>
#include<cstdio>
#include<string>
#include<vector>
#include<bitset>
#include<queue>
#include<cmath>
#include<stack>
#include<set>
#include<map>
// Competitive-programming shorthand macros shared by both solutions in this post.
// Object-like macros (ll, pb, mp, se, fi, pii, lson, rson) rewrite every later
// token with that spelling, so those names cannot be used as identifiers.
// rep/per: inclusive ascending / descending int loops (rep is used by main below).
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
// sz(a): container size as int. Being function-like, it only expands when the
// name is followed by '(' -- the global array sz[] below is unaffected.
#define sz(a) int(a.size())
// lson/rson: argument lists for recursive segment-tree calls (not referenced below).
#define rson mid+1,r,p<<1|1
#define pii pair<int,int>
#define lson l,mid,p<<1
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=1e9+7;
const int N=1e5+10;
const int inf=1e9;
int n,m;
// Node pool for the splay tree. Index 0 is the null node; real nodes are 1..tot.
//   ch[x][0]/ch[x][1] : left / right child      fa[x]  : parent
//   val[x]            : key stored in node x     cnt[x] : multiplicity of val[x]
//   sz[x]             : total multiplicity in x's subtree
// Indices are never recycled: clear() zeroes a node but tot only grows, so the
// number of insertions that create a new node must stay below N.
int rt,tot,ch[N][2],sz[N],val[N],fa[N],cnt[N];
// Splay tree over a multiset of ints: each distinct key is one node with a count.
struct Splay{
    // Recompute sz[x] from its children's sizes and its own multiplicity.
    void maintain(int x){
        sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
    }
    // Whether x is the right child of its parent.
    bool get(int x){
        return x==ch[fa[x]][1];
    }
    // Zero every field of node x.
    void clear(int x){
        ch[x][0]=ch[x][1]=fa[x]=val[x]=sz[x]=cnt[x]=0;
    }
    // Rotate x above its parent y, preserving the in-order sequence.
    void rotate(int x){
        int y=fa[x],z=fa[y],chk=get(x);
        ch[y][chk]=ch[x][chk^1];
        fa[ch[x][chk^1]]=y;      // may write fa[0]; the null node's parent is never trusted
        ch[x][chk^1]=y;
        fa[y]=x;
        fa[x]=z;
        if(z) ch[z][ y==ch[z][1] ]=x;
        // y is now x's child: refresh y first, otherwise sz[x] would be built
        // from y's stale (pre-rotation, larger) size.
        maintain(y);
        maintain(x);
    }
    // Splay x upward until its parent is goal; with goal==0, x becomes rt.
    // When x and its parent are on the same side, the parent is rotated first (zig-zig).
    void splay(int x,int goal=0){
        for(int f;(f=fa[x])!=goal;rotate(x)){
            if(fa[f]!=goal) rotate(get(x)==get(f)?f:x);
        }
        if(goal==0) rt=x;
    }
    // Insert one copy of k; the node holding k is splayed to the root.
    void ins(int k){
        if(!rt){
            val[++tot]=k;
            cnt[tot]++;
            rt=tot;
            maintain(rt);
            return;
        }
        int cnr=rt,f=0;
        while(1){
            if(val[cnr]==k){
                cnt[cnr]++;
                maintain(cnr);
                maintain(f);
                splay(cnr);      // also repairs the sizes of the remaining ancestors
                break;
            }
            f=cnr;
            cnr=ch[cnr][val[cnr]<k];
            if(!cnr){
                val[++tot]=k;
                cnt[tot]++;
                fa[tot]=f;
                ch[f][val[f]<k]=tot;
                maintain(tot);
                maintain(f);
                splay(tot);
                break;
            }
        }
    }
    // 1-based rank of k: (number of stored elements smaller than k) + 1.
    // If k is stored, its node is splayed to the root. If k is absent, the walk
    // stops at a null child and the last real node visited is splayed instead,
    // so the call always terminates (an empty tree returns 1).
    int rk(int k){
        int res=0,cnr=rt,last=0;
        while(cnr){
            last=cnr;
            if(k<val[cnr]){
                cnr=ch[cnr][0];
            }else{
                res+=sz[ch[cnr][0]];
                if(k==val[cnr]){
                    splay(cnr);
                    return res+1;
                }
                res+=cnt[cnr];
                cnr=ch[cnr][1];
            }
        }
        if(last) splay(last);
        return res+1;
    }
    // Value of the k-th smallest element (1-based, duplicates counted); its node
    // is splayed to the root. Requires 1 <= k <= sz[rt]: for larger k the walk
    // reaches the null node and never terminates; k <= 0 yields the minimum.
    int kth(int k){
        int cnr=rt;
        while(1){
            if(ch[cnr][0]&&sz[ch[cnr][0]]>=k){
                cnr=ch[cnr][0];
            }else{
                k-=cnt[cnr]+sz[ch[cnr][0]];
                if(k<=0){
                    splay(cnr);
                    return val[cnr];
                }
                cnr=ch[cnr][1];
            }
        }
    }
    // Predecessor of the current root: the largest node of the root's left
    // subtree, splayed to the root. Returns its node index (read the key via
    // val[]), or 0 when there is no left subtree or the tree is empty. In that
    // case nothing is splayed: splay(0) would end by setting rt to 0.
    int pre(){
        int cnr=ch[rt][0];
        if(!cnr) return 0;
        while(ch[cnr][1]) cnr=ch[cnr][1];
        splay(cnr);
        return cnr;
    }
    // Successor of the current root: the smallest node of the root's right
    // subtree, splayed to the root. Returns its node index, or 0 when there is
    // no right subtree or the tree is empty (nothing is splayed then).
    int nxt(){
        int cnr=ch[rt][1];
        if(!cnr) return 0;
        while(ch[cnr][0]) cnr=ch[cnr][0];
        splay(cnr);
        return cnr;
    }
    // Remove one copy of k. No-op when the tree is empty or k is not stored.
    void del(int k){
        rk(k);                           // splays k's node (if present) to the root
        if(!rt||val[rt]!=k) return;
        if(cnt[rt]>1){
            cnt[rt]--;
            maintain(rt);
            return;
        }
        if(!ch[rt][0]&&!ch[rt][1]){
            clear(rt);
            rt=0;
            return;
        }
        if(!ch[rt][0]){
            int cnr=rt;
            rt=ch[rt][1];
            fa[rt]=0;
            clear(cnr);
            return;
        }
        if(!ch[rt][1]){
            int cnr=rt;
            rt=ch[rt][0];
            fa[rt]=0;
            clear(cnr);
            return;
        }
        // Both subtrees exist: pre() splays the predecessor x to the root, which
        // leaves the old root as x's right child with an empty left subtree, so
        // it is spliced out by letting x adopt its right subtree.
        int cnr=rt,x=pre();
        fa[ch[cnr][1]]=x;
        ch[x][1]=ch[cnr][1];
        clear(cnr);
        maintain(rt);
    }
}S;
int main(){
    //ios::sync_with_stdio(false);
    //freopen("in","r",stdin);
    scanf("%d",&n);
    rep(i,1,n){
        int opt,x;
        scanf("%d%d",&opt,&x);
        if(opt==1){
            S.ins(x);
        }else if(opt==2){
            S.del(x);
        }else if(opt==3){
            S.ins(x);
            printf("%d\n",S.rk(x));
            S.del(x);
        }else if(opt==4){
            printf("%d\n",S.kth(x));
        }else if(opt==5){
            S.ins(x);
            printf("%d\n",val[S.pre()]);
            S.del(x);
        }else{
            S.ins(x);
            printf("%d\n",val[S.nxt()]);
            S.del(x);
        }
    }
    return 0;
}

区间翻转

文艺平衡树

#include<algorithm>
#include<iostream>
#include<cstring>
#include<iomanip>
#include<sstream>
#include<cstdio>
#include<string>
#include<vector>
#include<bitset>
#include<queue>
#include<cmath>
#include<stack>
#include<set>
#include<map>
// Competitive-programming shorthand macros (same set as the first solution).
// Object-like macros (ll, pb, mp, se, fi, pii, lson, rson) rewrite every later
// token with that spelling, so those names cannot be used as identifiers.
// rep/per: inclusive ascending / descending int loops (rep is used by main below).
#define rep(i,x,n) for(int i=x;i<=n;i++)
#define per(i,n,x) for(int i=n;i>=x;i--)
// sz(a): container size as int. Being function-like, it only expands when the
// name is followed by '(' -- the global array sz[] below is unaffected.
#define sz(a) int(a.size())
// lson/rson: argument lists for recursive segment-tree calls (not referenced below).
#define rson mid+1,r,p<<1|1
#define pii pair<int,int>
#define lson l,mid,p<<1
#define ll long long
#define pb push_back
#define mp make_pair
#define se second
#define fi first
using namespace std;
const double eps=1e-8;
const int mod=1e9+7;
const int N=1e5+10;
const int inf=1e9;
int n,m;
// Node pool (index 0 = null node, real nodes 1..tot):
//   ch[x][0/1] children, fa[x] parent, val[x] stored value, cnt[x] multiplicity,
//   sz[x] subtree size, tag[x] pending reversal: x's own children have not yet
//   been swapped and the flag has not yet been passed on to them.
int rt,tot,ch[N][2],sz[N],val[N],fa[N],cnt[N],tag[N];
// Splay tree whose in-order sequence is the array, with lazy interval reversal.
// ins() places nodes by comparing val[], so it is only meaningful before any
// reverse() has reordered the sequence.
struct Splay{
    // Recompute sz[x] from its children's sizes and its own multiplicity.
    void maintain(int x){
        sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x];
    }
    // Whether x is the right child of its parent.
    bool get(int x){
        return x==ch[fa[x]][1];
    }
    // Zero every field of node x, including its lazy tag.
    void clear(int x){
        ch[x][0]=ch[x][1]=fa[x]=val[x]=sz[x]=cnt[x]=0;
        tag[x]=0;
    }
    // Apply x's pending reversal: swap its children and flip their tags.
    void pd(int x){
        if(x&&tag[x]){
            tag[ch[x][0]]^=1;
            tag[ch[x][1]]^=1;
            swap(ch[x][0],ch[x][1]);
            tag[x]=0;
        }
    }
    // Rotate x above its parent y, preserving the in-order sequence.
    void rotate(int x){
        int y=fa[x],z=fa[y];
        // Push y before x, and only then read which side x is on: pushing y
        // swaps y's children (and may hand x a fresh tag), so reading the side
        // or pushing x earlier would act on a stale shape.
        pd(y),pd(x);
        int chk=get(x);
        ch[y][chk]=ch[x][chk^1];
        fa[ch[x][chk^1]]=y;
        ch[x][chk^1]=y;
        fa[y]=x;
        fa[x]=z;
        if(z) ch[z][ y==ch[z][1] ]=x;
        maintain(y);             // y is now x's child: refresh it first
        maintain(x);
    }
    // Splay x upward until its parent is goal; with goal==0, x becomes rt.
    void splay(int x,int goal=0){
        // Push pending tags from the root down to x first, so every get()
        // below sees the real shape of the tree.
        static int stk[N];
        int top=0;
        for(int y=x;y;y=fa[y]) stk[++top]=y;
        while(top) pd(stk[top--]);
        for(int f;(f=fa[x])!=goal;rotate(x)){
            if(fa[f]!=goal) rotate(get(x)==get(f)?f:x);
        }
        if(goal==0) rt=x;
    }
    // Insert one copy of k by value order; the node holding k is splayed to the root.
    void ins(int k){
        if(!rt){
            val[++tot]=k;
            cnt[tot]++;
            rt=tot;
            maintain(rt);
            return;
        }
        int cnr=rt,f=0;
        while(1){
            if(val[cnr]==k){
                cnt[cnr]++;
                maintain(cnr);
                maintain(f);
                splay(cnr);
                break;
            }
            f=cnr;
            cnr=ch[cnr][val[cnr]<k];
            if(!cnr){
                val[++tot]=k;
                cnt[tot]++;
                fa[tot]=f;
                ch[f][val[f]<k]=tot;
                maintain(tot);
                maintain(f);
                splay(tot);
                break;
            }
        }
    }
    // Node index at in-order position k (1-based), pushing tags on the way
    // down. Does not splay. Requires 1 <= k <= sz[rt].
    int kth(int k){
        int cnr=rt;
        while(1){
            pd(cnr);
            if(ch[cnr][0]&&sz[ch[cnr][0]]>=k){
                cnr=ch[cnr][0];
            }else{
                k-=cnt[cnr]+sz[ch[cnr][0]];
                if(k<=0){
                    return cnr;
                }
                cnr=ch[cnr][1];
            }
        }
    }
    // Reverse in-order positions x..y (1-based, sentinels included); needs
    // nodes at positions x-1 and y+1. Splaying those two to root / root's right
    // child isolates the interval as the right child's left subtree, which is tagged.
    void reverse(int x,int y){
        int l=kth(x-1),r=kth(y+1);
        splay(l,0);
        splay(r,l);
        int pos=ch[rt][1];
        pos=ch[pos][0];
        tag[pos]^=1;
    }
    // Print the in-order sequence under x, skipping the +-inf sentinels, each
    // value followed by a space. Iterative (explicit stack) because the tree
    // built by ascending inserts is a chain of depth ~n; prints nothing for x==0.
    void dfs(int x){
        static int stk[N];
        int top=0,cnr=x;
        while(cnr||top){
            while(cnr){
                pd(cnr);         // children must be final before descending
                stk[++top]=cnr;
                cnr=ch[cnr][0];
            }
            cnr=stk[top--];
            if(val[cnr]!=inf&&val[cnr]!=-inf){
                printf("%d ",val[cnr]);
            }
            cnr=ch[cnr][1];
        }
    }
}S;
int main(){
    //ios::sync_with_stdio(false);
    //freopen("in","r",stdin);
    scanf("%d%d",&n,&m);
    // Build the sequence 1..n framed by a -inf and a +inf sentinel, so that the
    // neighbours of any requested interval always exist.
    S.ins(-inf);
    for(int v=1;v<=n;++v) S.ins(v);
    S.ins(inf);
    for(int q=0;q<m;++q){
        int lo,hi;
        scanf("%d%d",&lo,&hi);
        // Shift by one: the leading sentinel occupies in-order position 1.
        S.reverse(lo+1,hi+1);
    }
    S.dfs(rt);
    return 0;
}
posted @ 2020-08-15 17:53  xyq0220  阅读(109)  评论(0)  编辑  收藏  举报