【模板】二逼平衡树(线段树+平衡树)

题目描述

您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:

1.查询k在区间内的排名

2.查询区间内排名为k的值

3.修改某一位值上的数值

4.查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)

5.查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)

n,m5104 保证有序序列所有值在任何时刻满足[0,108]

题解

这些操作都是平衡树的常见操作,考虑怎么维护区间。

用线段树即可,对序列建出线段树,每个节点维护一颗splay,splay维护区间的数。

1.只要查询区间有多少数比他小即可

2.不好直接查询,就只好二分值是多少再调用1判断

3.在一条链上删除和插入

4.对所有小区间查出来的前驱取max

5.对所有查出来的后继取min

代码写的函数有点多,很难受,不过还是按照自己的思路才添加的

#include<bits/stdc++.h>
using namespace std;
const int maxn=50005;
const int maxm=2000005;
const int oo=2147483647;
int n,m,o,cnt,num,ls[maxn<<1],rs[maxn<<1];
int a[maxn];
int root[maxn<<1];
struct Splay{
    int fa,s[2],size,tag;
    int val;
}tr[maxm];

template<class T>inline void read(T &x){
    x=0;char ch=getchar();
    while(!isdigit(ch)) ch=getchar();
    while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
}

void update(int x){
    tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+tr[x].tag;
}

int get(int x){
    return tr[tr[x].fa].s[1]==x;
}

void connect(int x,int y,int d){
    tr[x].fa=y;
    tr[y].s[d]=x;
}

void rotate(int x){
    int f=tr[x].fa,ff=tr[f].fa;
    int d1=get(x),d2=get(f);
    int cs=tr[x].s[d1^1];
    connect(x,ff,d2);
    connect(f,x,d1^1);
    connect(cs,f,d1);
    update(f);
    update(x);
}

void splay(int x,int go,int id){
    if(go==root[id]) root[id]=x;
    go=tr[go].fa;
    while(tr[x].fa!=go){
        int f=tr[x].fa;
        if(tr[f].fa==go) rotate(x);
        else if(get(f)==get(x)) {rotate(f);rotate(x);}
        else {rotate(x);rotate(x);}
    }
}

void insert(int val,int id){
    int now=root[id];
    if(!now){
        root[id]=++num;
        tr[num]=(Splay){0,{0,0},1,1,val};
        return ;
    }
    while(now){
        tr[now].size++;
        if(tr[now].val==val){
            tr[now].tag++;
            break;
        }
        int d=val>tr[now].val;
        if(!tr[now].s[d]){
            tr[now].s[d]=++num;
            tr[num]=(Splay){now,{0,0},1,1,val};
            now=num;
            break;
        }
        now=tr[now].s[d];
    }
    splay(now,root[id],id);
}

void modify(int &rt,int l,int r,int pos,int val){
    if(!rt) {
        rt=++cnt;
        insert(oo,rt);
        insert(-oo,rt);
    }
    insert(val,rt);
    if(l==r) return ;
    int mid=(l+r)>>1;
    if(pos<=mid) modify(ls[rt],l,mid,pos,val);
    else modify(rs[rt],mid+1,r,pos,val);
}

int query(int id,int val){
    int now=root[id],ret=0;
    while(now){
        if(tr[now].val==val) return ret+tr[tr[now].s[0]].size;
        else if(tr[now].val<val){
            ret+=tr[tr[now].s[0]].size+tr[now].tag;
            now=tr[now].s[1];
        }
        else now=tr[now].s[0];
    }
    return ret;
}

int seq_queryrank(int rt,int l,int r,int a_l,int a_r,int val){
    if(a_l<=l&&r<=a_r) return query(rt,val)-1;
    int ret=0,mid=(l+r)>>1;
    if(a_l<=mid) ret+=seq_queryrank(ls[rt],l,mid,a_l,a_r,val);
    if(mid<a_r) ret+=seq_queryrank(rs[rt],mid+1,r,a_l,a_r,val);
    return ret;
}

int querynumber(int l,int r,int k){
    int L=0,R=oo,ret;
    while(L<=R){
        int mid=(L+R)>>1;
        if(seq_queryrank(1,1,n,l,r,mid)+1<=k) ret=mid,L=mid+1;
        else R=mid-1;
    }
    return ret;
}

int findval(int id,int val){//查找值为x的是哪个 
    int now=root[id];
    while(1){
        if(tr[now].val==val) return now;
        else if(tr[now].val<val) now=tr[now].s[1];
        else now=tr[now].s[0];
    }
}

int findrank(int id,int k){//查找排名为x是哪个 
    int now=root[id];
    while(1){
        if(tr[tr[now].s[0]].size>=k) {now=tr[now].s[0];continue;}
        k-=tr[tr[now].s[0]].size;
        if(k<=tr[now].tag) return now;
        k-=tr[now].tag;
        now=tr[now].s[1];
    }
}

void dele(int id,int val){
    int now=findval(id,val);
    splay(now,root[id],id);
    if(tr[now].tag>1) {tr[now].tag--;tr[now].size--;return ;}
    int  k=tr[tr[now].s[0]].size,x=findrank(id,k),y=findrank(id,k+tr[now].tag+1);
    splay(x,root[id],id);
    splay(y,tr[x].s[1],id);
    tr[y].s[0]=0;
    update(y);update(x);
}

void get_dele(int rt,int l,int r,int pos,int val){
    dele(rt,val);
    if(l==r) return ;
    int mid=(l+r)>>1;
    if(pos<=mid) get_dele(ls[rt],l,mid,pos,val);
    else get_dele(rs[rt],mid+1,r,pos,val);
}

int querypre(int id,int val){
    int now=root[id],ans=-oo;
    while(now){
        if(tr[now].val<val){
            ans=max(ans,tr[now].val);
            now=tr[now].s[1];
        }
        else now=tr[now].s[0];
    }
    return ans;
}

int seg_querypre(int rt,int l,int r,int a_l,int a_r,int val){
    if(a_l<=l&&r<=a_r) return querypre(rt,val);
    int ans=-oo,mid=(l+r)>>1;
    if(a_l<=mid) ans=max(ans,seg_querypre(ls[rt],l,mid,a_l,a_r,val));
    if(mid<a_r) ans=max(ans,seg_querypre(rs[rt],mid+1,r,a_l,a_r,val));
    return ans;
}

int querynext(int id,int val){
    int now=root[id],ans=oo;
    while(now){
        if(tr[now].val>val){
            ans=min(ans,tr[now].val);
            now=tr[now].s[0];
        }
        else now=tr[now].s[1];
    }
    return ans;
}

int seg_querynext(int rt,int l,int r,int a_l,int a_r,int val){
    if(a_l<=l&&r<=a_r) return querynext(rt,val);
    int ans=oo,mid=(l+r)>>1;
    if(a_l<=mid) ans=min(ans,seg_querynext(ls[rt],l,mid,a_l,a_r,val));
    if(mid<a_r) ans=min(ans,seg_querynext(rs[rt],mid+1,r,a_l,a_r,val));
    return ans;
}

void debug(int x){
    if(tr[x].s[0]) debug(tr[x].s[0]);;
    printf("%d ",tr[x].val);
    if(tr[x].s[1]) debug(tr[x].s[1]);
}

int main(){
    read(n);read(m);
    for(int i=1;i<=n;i++){
        read(a[i]);
        modify(o,1,n,i,a[i]);
    }
    for(int i=1;i<=m;i++){
        int opt;read(opt);
        if(opt==1){
            int l,r,val;
            read(l);read(r);read(val);
            printf("%d\n",seq_queryrank(1,1,n,l,r,val)+1); 
        }
        else if(opt==2){
            int l,r,k;
            read(l);read(r);read(k);
            printf("%d\n",querynumber(l,r,k));
        }
        else if(opt==3){
            int pos,val;
            read(pos);read(val);
            get_dele(1,1,n,pos,a[pos]);
            modify(o,1,n,pos,a[pos]=val);
        }
        else if(opt==4){
            int l,r,val;
            read(l);read(r);read(val);
            printf("%d\n",seg_querypre(1,1,n,l,r,val));
        }
        else {
            int l,r,val;
            read(l);read(r);read(val);
            printf("%d\n",seg_querynext(1,1,n,l,r,val));
        }
    }
}
【模板】二逼平衡树(树套树)

当然用树状数组套值域线段树也是可以的,注意l-1这个细节就好

查询前驱就把x的排名p查出来,然后p=1就没有前驱,不然就查询排名是p-1的数。

查询后继,因为可能有很多数等于x,然后他们的排名虽然一样但会占位置,所以查x+1的排名p,如果p是最后一个,注意区间长度是r-l(因为查询输的l-1),就没有后继,不然查询排名为p的数。

#include<bits/stdc++.h>
using namespace std;

const int maxn=50005;
const int maxm=10000005;
const int oo=100000000;
const int cx=2147483647;
int n,m,a[maxn];
int cnt,root[maxn];
int ls[maxm],rs[maxm],size[maxm];

template<class T>inline void read(T &x){
    x=0;int f=0;char ch=getchar();
    while(!isdigit(ch)) {f|=(ch=='-');ch=getchar();}
    while(isdigit(ch)) {x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    x = f ? -x : x ;
}

void modify(int &rt,int l,int r,int pos,int val){
    if(!rt) rt=++cnt;
    size[rt]+=val;
    if(l==r) return ;
    int mid=(l+r)>>1;
    if(pos<=mid) modify(ls[rt],l,mid,pos,val);
    else modify(rs[rt],mid+1,r,pos,val);
}

void get_modify(int x,int pos,int val){for(;x<=n;x+=x&-x) modify(root[x],0,oo,pos,val);}

int tp1,tp2,s1[maxn],s2[maxn];

int querynumber(int x,int y,int k){
    tp1=tp2=0;
    for(;x;x-=x&-x) s1[++tp1]=root[x];
    for(;y;y-=y&-y) s2[++tp2]=root[y];
    int l=0,r=oo;
    while(1){
        if(l==r) return l;
      int sum=0,mid=(l+r)>>1;
      for(int i=1;i<=tp1;i++) sum-=size[ls[s1[i]]];
      for(int i=1;i<=tp2;i++) sum+=size[ls[s2[i]]];
      if(sum>=k){
          for(int i=1;i<=tp1;i++) s1[i]=ls[s1[i]];
          for(int i=1;i<=tp2;i++) s2[i]=ls[s2[i]];
          r=mid;
    }
      else {
          for(int i=1;i<=tp1;i++) s1[i]=rs[s1[i]];
          for(int i=1;i<=tp2;i++) s2[i]=rs[s2[i]];
          l=mid+1;k-=sum;
    }
    }
}

int queryrank(int x,int y,int pos){
    tp1=tp2=1;
    for(;x;x-=x&-x) s1[++tp1]=root[x];
    for(;y;y-=y&-y) s2[++tp2]=root[y];
    int l=0,r=oo,ret=0;
    while(1){
        if(l==r) return ret+1;
      int mid=(l+r)>>1;
      if(pos<=mid){
          for(int i=1;i<=tp1;i++) s1[i]=ls[s1[i]];
          for(int i=1;i<=tp2;i++) s2[i]=ls[s2[i]];
          r=mid;
    }
      else {
          for(int i=1;i<=tp1;i++) ret-=size[ls[s1[i]]],s1[i]=rs[s1[i]];
          for(int i=1;i<=tp2;i++) ret+=size[ls[s2[i]]],s2[i]=rs[s2[i]];
          l=mid+1;
    }
  }
}

int querypre(int l,int r,int pos){
    int p=queryrank(l,r,pos);
    if(p==1) return -cx;
    return querynumber(l,r,p-1);
}

int querynext(int l,int r,int pos){
    int p=queryrank(l,r,pos+1);
    if(p>r-l) return cx;
    return querynumber(l,r,p);
}

void print(int x){
    if(x<0) putchar('-'),x=-x;
  if(x>9) print(x/10);
  putchar(x%10+'0');
}

int main(){
    read(n);read(m);;
    for(int i=1;i<=n;i++){
        read(a[i]);
        get_modify(i,a[i],1);
    }
    for(int i=1;i<=m;i++){
        int opt;read(opt);
        if(opt==1){
            int l,r,val;
            read(l);read(r);read(val);
            print(queryrank(l-1,r,val)),putchar(10); 
        }
        else if(opt==2){
            int l,r,k;
            read(l);read(r);read(k);
            print(querynumber(l-1,r,k)),putchar(10);
        }
        else if(opt==3){
            int pos,val;
            read(pos);read(val);
            get_modify(pos,a[pos],-1);
            get_modify(pos,a[pos]=val,1);
        }
        else if(opt==4){
            int l,r,val;
            read(l);read(r);read(val);
            print(querypre(l-1,r,val)),putchar(10);
        }
        else {
            int l,r,val;
            read(l);read(r);read(val);
            print(querynext(l-1,r,val)),putchar(10);
        }
    }
}
树状数组+值域线段树

 

 对于前驱为什么不想后继那么查询,可以想一下,一开始强迫症搞成一样就错了。

还有为什么后继不能想前驱那么查。

 

 

posted @ 2019-08-25 21:12  _JSQ  阅读(368)  评论(0编辑  收藏  举报