BZOJ 3224（普通平衡树）

时间比较（SBT、无旋treap、splay、treap 等，数据取自各代码头部的提交记录）：替罪羊树 260 ms，treap 300 ms，SBT 352 ms，无旋treap 432 ms；splay 版本没有附提交记录。

 

 

这道题是照着别人的板子打的，思路不难，但过程中有很多细节要注意。

首先是treap

复制代码
/**************************************************************
    Problem: 3224
    Language: C++
    Result: Accepted
    Time:300 ms
    Memory:3164 kb
****************************************************************/
 
#include<cstdio>
#include<cctype>
#include<algorithm>
using namespace std;
// Global treap state: number of operations, root index, nodes allocated so far.
int n;
int rt;
int size;

// One treap node; index 0 acts as the empty tree (all fields zero).
struct data{
    int l, r;   // left / right child indices (0 = none)
    int val;    // stored key
    int cnt;    // how many copies of val this node holds
    int siz;    // total number of copies in this subtree
    int rnd;    // heap priority; _insert rotates so the smaller value sits higher
}tr[100005];
 
// Reads one decimal integer from stdin. Characters before the first digit
// are skipped; any '-' among them makes the result negative.
inline int read(){
    int sign=1,value=0;
    char ch;
    for(ch=getchar();!isdigit(ch);ch=getchar())
        if(ch=='-') sign=-1;
    for(;isdigit(ch);ch=getchar())
        value=value*10+(ch-'0');
    return value*sign;
}
 
inline void update(int root){
    tr[root].siz=tr[tr[root].l].siz+tr[tr[root].r].siz+tr[root].cnt;
}
 
inline int rand(){
    static int seed = 2333;
    return seed = (int)((((seed ^ 998244353) + 19260817ll) * 19890604ll) % 1000000007);
}
 
void lturn(int &root){
    int t=tr[root].r;tr[root].r=tr[t].l;tr[t].l=root;
    tr[t].siz=tr[root].siz;update(root);root=t;
}
 
void rturn(int &root){
    int t=tr[root].l;tr[root].l=tr[t].r;tr[t].r=root;
    tr[t].siz=tr[root].siz;update(root);root=t;
}
 
void _insert(int &root,int now){
    if(root==0){
        root=++size;tr[root].siz=tr[root].cnt=1;tr[root].val=now;tr[root].rnd=rand();return;
    }
    tr[root].siz++;
    if(tr[root].val==now) tr[root].cnt++;
    else if(tr[root].val>now){
        _insert(tr[root].l,now);
        if(tr[root].rnd>tr[tr[root].l].rnd) rturn(root);
    }
    else{
        _insert(tr[root].r,now);
        if(tr[root].rnd>tr[tr[root].r].rnd) lturn(root);
    }
}
 
void _del(int &root,int now){
    if(root==0) return;
    if(tr[root].val==now){
        if(tr[root].cnt>1){tr[root].cnt--;tr[root].siz--;}
        else if(tr[root].l==0 || tr[root].r==0) root=tr[root].l+tr[root].r;
        else if(tr[tr[root].r].rnd>tr[tr[root].l].rnd) rturn(root),_del(root,now);
        else lturn(root),_del(root,now);
    }
    else if(tr[root].val>now) tr[root].siz--,_del(tr[root].l,now);
    else tr[root].siz--,_del(tr[root].r,now);
}
 
// Rank query. If `now` is stored, returns its 1-based rank (number of smaller
// copies + 1). If it is absent, returns only the number of smaller copies.
int query_xpm(int root,int now){
    int smaller=0;
    while(root!=0){
        if(tr[root].val==now) return smaller+tr[tr[root].l].siz+1;
        if(tr[root].val<now){
            smaller+=tr[tr[root].l].siz+tr[root].cnt;
            root=tr[root].r;
        }else{
            root=tr[root].l;
        }
    }
    return smaller;
}
 
// Returns the value of the now-th smallest copy (1-based, duplicates counted),
// or 0 if the walk runs off the tree.
int query_pmx(int root,int now){
    while(root!=0){
        int leftSize=tr[tr[root].l].siz;
        if(now<=leftSize){root=tr[root].l;continue;}
        if(now<=leftSize+tr[root].cnt) return tr[root].val;
        now-=leftSize+tr[root].cnt;
        root=tr[root].r;
    }
    return 0;
}
 
// Predecessor: the largest stored value strictly less than `now`, or -2e9 if none.
int query_qq(int root,int now){
    int best=-2e9;
    while(root){
        if(tr[root].val<now){best=max(best,tr[root].val);root=tr[root].r;}
        else root=tr[root].l;
    }
    return best;
}
 
// Successor: the smallest stored value strictly greater than `now`, or 2e9 if none.
int query_hj(int root,int now){
    int best=2e9;
    while(root){
        if(tr[root].val>now){best=min(best,tr[root].val);root=tr[root].l;}
        else root=tr[root].r;
    }
    return best;
}
 
// Reads n operations and dispatches them:
// 1 insert x, 2 delete one x, 3 rank of x, 4 x-th smallest,
// 5 predecessor of x, 6 successor of x.
int main(){
    n=read();
    int flag,x;
    for(int i=1;i<=n;i++){
        flag=read();x=read();
        // The query functions return int, so the conversion must be %d;
        // passing an int where %ld expects a long is undefined behaviour.
        if(flag==1) _insert(rt,x);
        if(flag==2) _del(rt,x);
        if(flag==3) printf("%d\n",query_xpm(rt,x));
        if(flag==4) printf("%d\n",query_pmx(rt,x));
        if(flag==5) printf("%d\n",query_qq(rt,x));
        if(flag==6) printf("%d\n",query_hj(rt,x));
    }
    return 0;
}
复制代码

然后是splay

复制代码
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
#define maxn 1000005
// Splay-tree node storage; index 0 is the null node.
int tr[maxn][2];   // tr[x][0] / tr[x][1]: left / right child
int f[maxn];       // parent (0 for the root)
int siz[maxn];     // total multiplicity in x's subtree
int cnt[maxn];     // multiplicity of key[x]
int key[maxn];     // stored value
int sz;            // number of node slots handed out so far
int rt;            // current root (0 when the tree is empty)
// Zeroes every field of node x so its slot reads as empty.
inline void clear(int x){
    f[x]=siz[x]=cnt[x]=key[x]=0;
    tr[x][0]=tr[x][1]=0;
}
// Recomputes siz[x] from x's own count plus the sizes of both children.
inline void updata(int x){
    int total=cnt[x];
    for(int side=0;side<2;++side) total+=siz[tr[x][side]];
    siz[x]=total;
}
// Rotates x above its parent `old` while preserving in-order order.
// whicx records which side of old x hangs on. x's subtree on the opposite side
// (the one facing old) is re-attached to old in x's former slot, and old
// becomes x's child. Parent links are fixed. If old had a parent (oldf), that
// parent's child pointer is redirected to x. Sizes of old and then x are
// recomputed. When the moved subtree is empty, f[0] gets written.
inline void rotate(int x){
    int old=f[x],oldf=f[old],whicx=(tr[f[x]][1]==x);
    tr[old][whicx]=tr[x][whicx^1];tr[x][whicx^1]=old;
    f[tr[old][whicx]]=old;f[old]=x;f[x]=oldf;
    if(oldf)tr[oldf][tr[oldf][1]==old]=x;
    updata(old);updata(x);
}
// Splays x to the root of the tree and records it in rt.
// While x has a parent fa, x is rotated above fa (the loop's increment step).
// When a grandparent exists, one extra rotation happens first:
// - the parent is rotated when x and fa sit on the same side of their parents
//   (zig-zig);
// - otherwise x itself is rotated (zig-zag).
inline void splay(int x){
    // the condition assigns fa = f[x]; the loop stops once x has no parent
    for(int fa;fa=f[x];rotate(x))
        if(f[fa])rotate((tr[f[x]][1]==x)==(tr[f[fa]][1]==fa)?fa:x);
    rt=x;
}
inline void insert(int x){
    if(rt==0){sz++;tr[sz][0]=tr[sz][1]=f[sz]=0;rt=sz;siz[sz]=cnt[sz]=1;key[sz]=x;return;}
    int now=rt,fa=0;
    while(1){
        if(x==key[now]){cnt[now]++,updata(now),updata(fa),splay(now);break;}
        fa=now;now=tr[now][key[now]<x];
        if(now==0){
            sz++;
            tr[sz][0]=tr[sz][1]=0,f[sz]=fa;key[sz]=x;
            siz[sz]=cnt[sz]=1;tr[fa][key[fa]<x]=sz;
            updata(fa);splay(sz);
            break;
        }
    }
}
inline int find(int x){
    int now=rt,ans=0;
    while(1){
        if(x<key[now])now=tr[now][0];
        else{
            ans+=(tr[now][0]?siz[tr[now][0]]:0);
            if(x==key[now]){splay(now);return ans+1;}
            ans+=cnt[now];now=tr[now][1];
        }
    }
}
// Returns the value of the x-th smallest copy (1-based, duplicates counted).
inline int findx(int x){
    for(int node=rt;;){
        int left=tr[node][0];
        if(left&&x<=siz[left]){node=left;continue;}
        int upto=(left?siz[left]:0)+cnt[node];
        if(x<=upto) return key[node];
        x-=upto;
        node=tr[node][1];
    }
}
// Returns the right-most node of the root's left subtree, i.e. the in-order
// predecessor of the root (0 when the root has no left child).
inline int pre(){
    int node;
    for(node=tr[rt][0];tr[node][1]!=0;node=tr[node][1]){}
    return node;
}
// Returns the left-most node of the root's right subtree, i.e. the in-order
// successor of the root (0 when the root has no right child).
inline int next(){
    int node;
    for(node=tr[rt][1];tr[node][0]!=0;node=tr[node][0]){}
    return node;
}
// Removes one copy of x (expected to be present). find() splays x's node to
// the root. Then the root either loses one count, is replaced by its only
// child, or is replaced by its predecessor (splayed up), which adopts the old
// root's right subtree.
inline void del(int x){
    find(x);    // called only for its splay side effect; the rank is not needed
    if(cnt[rt]>1){cnt[rt]--;updata(rt);return;}
    if(!tr[rt][0]&&!tr[rt][1]){clear(rt);rt=0;return;}
    if(!tr[rt][0]){int oldrt=rt;rt=tr[rt][1];f[rt]=0;clear(oldrt);return;}
    if(!tr[rt][1]){int oldrt=rt;rt=tr[rt][0];f[rt]=0;clear(oldrt);return;}
    int lefb=pre(),oldrt=rt;
    splay(lefb);
    // The code relies on oldrt now being lefb's right child with no left
    // child. So oldrt's right subtree is hung directly under the new root.
    tr[rt][1]=tr[oldrt][1];
    f[tr[oldrt][1]]=rt;
    clear(oldrt);
    updata(rt);
}
// Reads n operations: 1 insert, 2 delete, 3 rank, 4 k-th smallest,
// 5 predecessor, 6 successor. Predecessor/successor temporarily insert x,
// read the root's in-order neighbour and then delete x again.
int main(){
    int n;
    scanf("%d",&n);
    for(int i=0;i<n;++i){
        int opt,x;
        scanf("%d%d",&opt,&x);
        if(opt==1) insert(x);
        else if(opt==2) del(x);
        else if(opt==3) printf("%d\n",find(x));
        else if(opt==4) printf("%d\n",findx(x));
        else if(opt==5){insert(x);printf("%d\n",key[pre()]);del(x);}
        else if(opt==6){insert(x);printf("%d\n",key[next()]);del(x);}
    }
}
复制代码

 

 3.无旋treap

复制代码
复制代码
/**************************************************************
    Problem: 3224
    Language: C++
    Result: Accepted
    Time:432 ms
    Memory:2780 kb
****************************************************************/
 
#include<cstdio>
#include<cctype>
#include<algorithm>
#define mp make_pair<int,int>
using namespace std;
// Global state of the split/merge (non-rotating) treap.
int cnt, rt, n;              // nodes allocated, root index, operation count
typedef pair<int,int> par;   // (left root, right root) returned by split

// Treap node; index 0 is the empty tree.
struct data{
    int l, r;   // children
    int key;    // random heap priority (merge keeps the smaller key on top)
    int data;   // stored value
    int siz;    // number of nodes in this subtree
}tr[100002];
 
// Reads a decimal integer from stdin into x. Characters before the first
// digit are skipped; any '-' among them makes the result negative.
void read(int &x){
    int sign=1;
    char ch=getchar();
    while(!isdigit(ch)){
        if(ch=='-') sign=-1;
        ch=getchar();
    }
    for(x=0;isdigit(ch);ch=getchar()) x=x*10+(ch-'0');
    x*=sign;
}
 
// Sets now's subtree size to 1 (itself) plus the sizes of both children.
void updata(int now){
    int s=1;
    s+=tr[tr[now].l].siz;
    s+=tr[tr[now].r].siz;
    tr[now].siz=s;
}
 
// Rank query on the global treap. If x is stored, returns its smallest 1-based
// position in sorted order. Otherwise returns how many stored values are
// smaller than x.
int rank(int x){
    const int NONE=(int)1e9;
    int best=NONE,before=0;
    for(int k=rt;k;){
        int here=before+tr[tr[k].l].siz+1;   // sorted position of node k
        if(tr[k].data==x&&here<best) best=here;
        if(tr[k].data<x){before=here;k=tr[k].r;}
        else k=tr[k].l;
    }
    return best==NONE?before:best;
}
 
// Splits the treap rooted at x by position: the first k nodes go to the
// first root of the returned pair, the rest to the second.
// Results are built with the par constructor. The file's `mp` macro expands
// to make_pair<int,int>, whose explicit template arguments turn its
// parameters into int&& since C++11. Passing lvalues such as x, ls or
// tmp.first to it therefore does not compile.
par split(int x,int k){
    if(k==0)return par(0,x);
    int ls=tr[x].l,rs=tr[x].r;
    if(k==tr[ls].siz){tr[x].l=0;updata(x);return par(ls,x);}
    if(k==tr[ls].siz+1){tr[x].r=0;updata(x);return par(x,rs);}
    if(k<tr[ls].siz){
        par tmp=split(ls,k);
        tr[x].l=tmp.second;updata(x);
        return par(tmp.first,x);
    }
    par tmp=split(rs,k-tr[ls].siz-1);
    tr[x].r=tmp.first;updata(x);
    return par(x,tmp.second);
}
 
// Joins treap a (expected to hold the earlier positions) with treap b and
// returns the new root. The root with the smaller key stays on top.
int merge(int a,int b){
    if(!a) return b;
    if(!b) return a;
    if(tr[a].key<tr[b].key){
        tr[a].r=merge(tr[a].r,b);
        updata(a);
        return a;
    }
    tr[b].l=merge(a,tr[b].l);
    updata(b);
    return b;
}
 
// Inserts x into the global treap: split after position rank(x), then merge
// the left part, a new single-node treap, and the right part.
// `::rank` is qualified because, under `using namespace std`, an unqualified
// `rank` also finds the std::rank type trait (C++11) and the call is ambiguous.
void insert(int x){
    int k=::rank(x);
    par tmp=split(rt,k);
    tr[++cnt].data=x;
    tr[cnt].key=rand();
    tr[cnt].siz=1;
    rt=merge(tmp.first,cnt);
    rt=merge(rt,tmp.second);
}
 
void del(int x){
    int k=rank(x);
    par tmp=split(rt,k);
    par tmp2=split(tmp.first,k-1);
    rt=merge(tmp2.first,tmp.second);
}
 
// Returns the value stored at 1-based position k in the treap rooted at x.
int askrank(int x,int k){
    for(;;){
        int left=tr[tr[x].l].siz;
        if(k==left+1) return tr[x].data;
        if(k>left){k-=left+1;x=tr[x].r;}
        else x=tr[x].l;
    }
}
 
// Largest stored value strictly below x in the treap rooted at now,
// or -(int)1e9 when there is none.
int pre(int now,int x){
    int best=-(int)1e9;
    for(;now;){
        const int v=tr[now].data;
        if(v<x){
            if(v>best) best=v;
            now=tr[now].r;
        }else{
            now=tr[now].l;
        }
    }
    return best;
}
 
// Smallest stored value strictly above x in the treap rooted at now,
// or (int)2e9 when there is none.
int nex(int now,int x){
    int best=(int)2e9;
    for(;now;){
        const int v=tr[now].data;
        if(v>x){
            if(v<best) best=v;
            now=tr[now].l;
        }else{
            now=tr[now].r;
        }
    }
    return best;
}
 
// Reads n operations: 1 insert, 2 delete, 3 rank, 4 k-th smallest,
// 5 predecessor, 6 successor.
// `::rank` is qualified because an unqualified `rank` is ambiguous with
// std::rank under `using namespace std` (C++11 and later).
int main(){
    read(n);
    for(int i=1;i<=n;i++){
        int opt,x;
        read(opt);read(x);
        switch(opt){
            case 1:insert(x);break;
            case 2:del(x);break;
            case 3:printf("%d\n",::rank(x));break;
            case 4:printf("%d\n",askrank(rt,x));break;
            case 5:printf("%d\n",pre(rt,x));break;
            case 6:printf("%d\n",nex(rt,x));break;
        }
    }
}
复制代码

 4.SBT

复制代码
/**************************************************************
    Problem: 3224
    Language: C++
    Result: Accepted
    Time:352 ms
    Memory:2400 kb
****************************************************************/
 
#include<cstdio>
#include<cctype>
#include<cstring>
#include<algorithm>
#define maxn 100001
using namespace std;
int n;   // number of operations still to process

// Parses one integer (optionally negative) from stdin into x. Characters
// before the first digit are skipped; any '-' among them negates the result.
void read(int &x){
    char c;
    int neg=0;
    for(c=getchar();!isdigit(c);c=getchar()) if(c=='-') neg=1;
    x=0;
    for(;isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c-'0');
    if(neg) x=-x;
}
struct SBT{
    int rt,cnt;
    int key[maxn],siz[maxn],ls[maxn],rs[maxn];
    void clear(){
        rt=0;cnt=0;
        memset(key,0,sizeof(key));
        memset(siz,0,sizeof(siz));
        memset(ls,0,sizeof(ls));
        memset(rs,0,sizeof(rs));
    }
    void zig(int &p){
        int k=rs[p];
        rs[p]=ls[k];
        ls[k]=p;
        siz[k]=siz[p];
        siz[p]=siz[ls[p]]+siz[rs[p]]+1;
        p=k;
    }
    void zag(int &p){
        int k=ls[p];
        ls[p]=rs[k];
        rs[k]=p;
        siz[k]=siz[p];
        siz[p]=siz[ls[p]]+siz[rs[p]]+1;
        p=k;
    }
    void maintain(int &p,bool flag){
        if(!flag){
            if(siz[ls[ls[p]]]>siz[rs[p]])zag(p);
            else{
                if(siz[rs[ls[p]]]>siz[rs[p]]){
                    zig(ls[p]);zag(p);
                }else return;
            }
        }else{
            if(siz[rs[rs[p]]]>siz[ls[p]])zig(p);
            else{
                if(siz[ls[rs[p]]]>siz[ls[p]]){
                    zag(rs[p]);
                    zig(p);
                }else return;
            }
        }
        maintain(ls[p],false);
        maintain(rs[p],true);
        maintain(p,true);
        maintain(p,false);
    }
    void insert(int &p,int x){
        if(!p){
            p=++cnt;key[p]=x;siz[p]=1;return;
        }
        siz[p]++;
        if(x<key[p])insert(ls[p],x);else insert(rs[p],x);
        maintain(p,x>=key[p]);
    }
    int erase(int &p,int x){
        siz[p]--;int tmp;
        if(x==key[p] ||(x<key[p] && !ls[p])||(x>key[p] && !rs[p])){
            tmp=key[p];
            if(!ls[p] || !rs[p])p=ls[p]+rs[p];
            else key[p]=erase(ls[p],key[p]+1);
            return tmp;
        }
        if(x<key[p])tmp=erase(ls[p],x);else tmp=erase(rs[p],x);
        return tmp;
    }
    int rank(int &p,int x){
        if(!p)return 1;int tmp=0;
        if(x<=key[p])tmp=rank(ls[p],x);
        else tmp=siz[ls[p]]+1+rank(rs[p],x);
        return tmp;
    }
    int askrank(int &p,int x){
        if(x==siz[ls[p]]+1)return key[p];
        if(x<=siz[ls[p]])return askrank(ls[p],x);
        else return askrank(rs[p],x-1-siz[ls[p]]);
    }
    int pre(int &p,int x){
        if(!p)return x;int tmp;
        if(x<=key[p])tmp=pre(ls[p],x);
        else{tmp=pre(rs[p],x);if(tmp==x)tmp=key[p];}
        return tmp;
    }
    int nex(int &p,int x){
        if(!p)return x;int tmp;
        if(x>=key[p])tmp=nex(rs[p],x);
        else{tmp=nex(ls[p],x);if(tmp==x)tmp=key[p];}
        return tmp;
    }
}T;
// Drives the global tree T: reads n, then n pairs (opt, x) with
// 1 insert, 2 erase, 3 rank, 4 k-th smallest, 5 predecessor, 6 successor.
int main(){
    read(n);
    T.clear();
    T.rt=0;
    int &root=T.rt;
    while(n--){
        int opt,x;
        read(opt);
        read(x);
        if(opt==1) T.insert(root,x);
        else if(opt==2) T.erase(root,x);
        else if(opt==3) printf("%d\n",T.rank(root,x));
        else if(opt==4) printf("%d\n",T.askrank(root,x));
        else if(opt==5) printf("%d\n",T.pre(root,x));
        else if(opt==6) printf("%d\n",T.nex(root,x));
    }
}
复制代码

 5.替罪羊树

复制代码
/**************************************************************
    Problem: 3224
    Language: C++
    Result: Accepted
    Time:260 ms
    Memory:5512 kb
****************************************************************/
 
#include<cctype>
#include<cstdio>
#include<algorithm>
#define maxn 200005
#define bz 0.75
using namespace std;
int n;          // number of operations
int cnt;        // nodes allocated so far
int rt;         // current root
int cur[maxn];  // scratch list of node indices filled by recycle()
int sum;        // number of entries used in cur
// Scapegoat-tree node; index 0 is the null node.
struct data{
  int son[2];   // children (0 = none)
  int fa;       // parent
  int siz;      // nodes in this subtree
  int val;      // stored value
}tr[maxn];
// Reads one decimal integer from stdin. Characters before the first digit
// are skipped; any '-' among them makes the result negative.
inline int read(){
  int sgn=1,val=0;
  char c;
  do{
    c=getchar();
    if(c=='-') sgn=-1;
  }while(!isdigit(c));
  while(isdigit(c)){
    val=val*10+(c-'0');
    c=getchar();
  }
  return sgn*val;
}
inline bool balance(int x){
  return (double)tr[x].siz*bz>=(double)tr[tr[x].son[0]].siz && (double) tr[x].siz*bz>=(double)tr[tr[x].son[1]].siz;
} 
inline void recycle(int x){
  if(tr[x].son[0])recycle(tr[x].son[0]);cur[++sum]=x;if(tr[x].son[1])recycle(tr[x].son[1]);
}
// Builds a perfectly balanced subtree from the in-order node list cur[l..r].
// Child links, parent links and sizes are set; returns the subtree root, or 0
// for an empty range. An empty child range writes tr[0].fa.
inline int build(int l,int r){
  if(r<l) return 0;
  int mid=(l+r)>>1;
  int node=cur[mid];
  int lc=build(l,mid-1);
  tr[node].son[0]=lc;
  tr[lc].fa=node;
  int rc=build(mid+1,r);
  tr[node].son[1]=rc;
  tr[rc].fa=node;
  tr[node].siz=tr[lc].siz+tr[rc].siz+1;
  return node;
}
// Flattens the subtree rooted at x and replaces it with a balanced rebuild.
// The new subtree root is attached on the same side of x's old parent, and
// becomes rt if x was the root.
inline void rebuild(int x){
  sum=0;
  recycle(x);
  int parent=tr[x].fa;                       // captured before build() relinks x
  int side=(tr[parent].son[1]==x);
  int top=build(1,sum);
  tr[parent].son[side]=top;
  tr[top].fa=parent;
  if(rt==x) rt=top;
}
inline void insert(int x){
  int now=rt,cur=++cnt;
  tr[cur].siz=1,tr[cur].val=x;
  while(1){
    tr[now].siz++;
    bool whicx=(x>=tr[now].val);
    if(tr[now].son[whicx])now=tr[now].son[whicx];
    else{
      tr[tr[now].son[whicx]=cur].fa=now;break;
    }
  }
  int flag=0;
  for(int i=cur;i;i=tr[i].fa)if(!balance(i))flag=i;
  if(flag)rebuild(flag);
}
// Returns the index of a node whose value equals x. There is no not-found
// exit, so x is expected to be present.
inline int get_num(int x){
  int node=rt;
  while(tr[node].val!=x) node=tr[node].son[tr[node].val<x];
  return node;
}
  
// Physically removes node x. A node with two children first copies its
// in-order predecessor's value, and the predecessor node is removed instead.
// The removed node's only child (or 0) takes its place, every ancestor's size
// is decremented, and rt is updated if the root was removed.
inline void del(int x){
  int victim=x;
  if(tr[victim].son[0]&&tr[victim].son[1]){
    int p=tr[victim].son[0];
    while(tr[p].son[1]) p=tr[p].son[1];
    tr[victim].val=tr[p].val;
    victim=p;
  }
  int child=tr[victim].son[0]?tr[victim].son[0]:tr[victim].son[1];
  int parent=tr[victim].fa;
  int side=(tr[parent].son[1]==victim);
  tr[parent].son[side]=child;
  tr[child].fa=parent;
  for(int a=parent;a;a=tr[a].fa) --tr[a].siz;
  if(rt==victim) rt=child;
}
// Counts stored nodes whose value is strictly less than x. Assuming the -2e9
// sentinel set up in main() is present, it is always counted, so for a stored
// x the result is x's 1-based rank among the real values.
inline int get_rank(int x){
  int less=0;
  for(int p=rt;p;){
    if(tr[p].val>=x){p=tr[p].son[0];continue;}
    less+=tr[tr[p].son[0]].siz+1;
    p=tr[p].son[1];
  }
  return less;
}
// Returns the index of the x-th node in in-order order (1-based, sentinel
// nodes included).
inline int get_kth(int x){
  int p=rt;
  for(;;){
    int left=tr[tr[p].son[0]].siz;
    if(left==x-1) return p;
    if(left>=x) p=tr[p].son[0];
    else{x-=left+1;p=tr[p].son[1];}
  }
}
// Largest stored value strictly below x, or -2e9 if none.
inline int get_pre(int x){
  int best=-2e9;
  int p=rt;
  while(p){
    const int v=tr[p].val;
    if(v<x){if(v>best) best=v;p=tr[p].son[1];}
    else p=tr[p].son[0];
  }
  return best;
}
// Smallest stored value strictly above x, or 2e9 if none.
inline int get_suc(int x){
  int best=2e9;
  int p=rt;
  while(p){
    const int v=tr[p].val;
    if(v>x){if(v<best) best=v;p=tr[p].son[0];}
    else p=tr[p].son[1];
  }
  return best;
}
// Sets up two sentinels: node 1 holds -2e9 and is the root, node 2 holds 2e9
// and is its right child. Then it processes n operations: 1 insert,
// 2 delete, 3 rank, 4 k-th smallest (x+1 skips the -2e9 sentinel),
// 5 predecessor, 6 successor.
int main(){
  rt=1;
  cnt=2;
  tr[1].val=-2e9;
  tr[1].siz=2;
  tr[1].son[1]=2;
  tr[2].val=2e9;
  tr[2].siz=1;
  tr[2].fa=1;
  n=read();
  for(int i=1;i<=n;i++){
    int typ=read();
    int x=read();
    if(typ==1) insert(x);
    else if(typ==2) del(get_num(x));
    else if(typ==3) printf("%d\n",get_rank(x));
    else if(typ==4) printf("%d\n",tr[get_kth(x+1)].val);
    else if(typ==5) printf("%d\n",get_pre(x));
    else if(typ==6) printf("%d\n",get_suc(x));
  }
}
复制代码

 

复制代码
posted @   lnyzo  阅读(154)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示