[bzoj3196]二逼平衡树
您需要写一种数据结构(可参考题目标题),来维护一个有序数列,其中需要提供以下操作:
1.查询k在区间内的排名
2.查询区间内排名为k的值
3.修改某一位值上的数值
4.查询k在区间内的前驱(前驱定义为小于x,且最大的数)
5.查询k在区间内的后继(后继定义为大于x,且最小的数)
树套树;
外层线段树,内层splay;
具体的做法是在线段树的每个节点上建立一颗splay,利用splay维护每个线段树节点上的信息;
线段树一共有logn层,每层的大小是n,空间复杂度是nlogn,时间复杂度是nlognlogn(?);
常数大得惊人,10s的时限,跑了9s9,汗。
实际上在这种不需要树的合并的场合,用treap就行了;
#include<iostream> #include<cstdio> #include<cstdlib> #include<cstring> #include<string> #include<ctime> #include<cmath> #include<set> #include<map> #include<queue> #include<algorithm> #include<iomanip> #include<stack> using namespace std; #define FILE "dealing" #define up(i,j,n) for(int i=(j);i<=(n);i++) #define pii pair<int,int> #define LL int #define mem(f,g) memset(f,g,sizeof(f)) namespace IO{ char buf[1<<15],*fs,*ft; int gc(){return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?-1:*fs++;} int read(){ int ch=gc(),f=0,x=0; while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=gc();} while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=gc();} return f?-x:x; } int readint(){ int ch=getchar(),f=0,x=0; while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();} while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();} return f?-x:x; } }using namespace IO; const int maxn=4001000,inf=1000000000; int n,m; int a[maxn]; int c[maxn][2],v[maxn],siz[maxn],fa[maxn],root[maxn],t[maxn],cnt; void updata(int x){siz[x]=siz[c[x][0]]+siz[c[x][1]]+t[x];} void rotate(int x){ int k=fa[x]; int d=(c[k][1]==x); fa[x]=fa[k];fa[k]=x;fa[c[x][d^1]]=k; c[k][d]=c[x][d^1];c[x][d^1]=k; if(fa[x])c[fa[x]][c[fa[x]][1]==k]=x; updata(k);updata(x); } void splay(int x,int s,int rt){ if(fa[x]==s)return; while(fa[x]!=s){ if(fa[fa[x]]==s)rotate(x); else { int y=fa[x],z=fa[y]; if(c[y][1]==x^c[z][1]==y)rotate(x); else rotate(y); rotate(x); } } if(!s)root[rt]=x; } void insert(int key,int rt){ if(!root[rt]){ root[rt]=++cnt,t[cnt]=siz[cnt]=1,v[cnt]=key,fa[cnt]=0; return; } int now=root[rt],y; while(now){ if(key==v[now]){splay(now,0,rt),t[now]++,updata(now);return;} y=now,now=c[now][key>v[now]]; } now=++cnt;t[now]=siz[now]=1;fa[now]=y;c[y][key>v[y]]=now;v[now]=key; splay(now,0,rt); } void build(int l,int r,int rt){ if(l>r)return; insert(-inf,rt);//哨兵 insert(inf<<1,rt); up(i,l,r)insert(a[i],rt); int mid=(l+r)>>1; if(l!=r){ build(l,mid,rt<<1); build(mid+1,r,rt<<1|1); } } int x,y,sum,key,pos; int find(int key,int rt){//在rt的子树中寻找最小的大于key的节点 int now=root[rt],id=0; while(now){ if(v[now]>key&&(v[now]<v[id]||!id))id=now; now=c[now][key>=v[now]]; } splay(id,0,rt); return id; } int findpre(int key,int rt){//在rt的子树中寻找最大的小于key的节点 int now=root[rt],id=0; while(now){ if(v[now]<key&&(v[now]>v[id]||!id))id=now; now=c[now][key>v[now]]; } splay(id,0,rt); return id; } void query_rank(int l,int r,int rt){ if(l>y||r<x)return; if(x<=l&&r<=y){ int now=find(key,rt); sum+=siz[c[now][0]]-1; return; } int mid=(l+r)>>1; query_rank(l,mid,rt<<1); query_rank(mid+1,r,rt<<1|1); } int getrank(int k,int l,int r){ key=k;sum=0;x=l,y=r; query_rank(1,n,1); return sum; } int getK(int k,int l,int r){ int x=0,y=inf,mid; while(x+1!=y){ mid=(x+y)>>1; if(getrank(mid,l,r)<=k)x=mid; else y=mid; } if(getrank(x,l,r)<=k&&getrank(y,l,r)>k)return y; return x; } int getl(int x){while(c[x][0])x=c[x][0];return x;} int getr(int x){while(c[x][1])x=c[x][1];return x;} void delet(int k,int rt){ int x=find(k-1,rt); splay(x,0,rt); int l=c[x][0],r=c[x][1]; l=getr(l);r=getl(r); splay(l,0,rt);splay(r,l,rt); if(t[x]>1){t[x]--;updata(x);} else c[r][0]=0; updata(r),updata(l); } void Change(int l,int r,int rt){ if(l>pos||r<pos)return; delet(a[pos],rt); insert(key,rt); int mid=(l+r)>>1; if(l!=r){ Change(l,mid,rt<<1); Change(mid+1,r,rt<<1|1); } return; } void change(int p,int k){ pos=p,key=k; Change(1,n,1); } void query_min(int l,int r,int rt){ if(l>y||r<x)return; if(l>=x&&r<=y){ int y=findpre(key,rt); if(v[y]>sum)sum=v[y]; return; } int mid=(l+r)>>1; query_min(l,mid,rt<<1); query_min(mid+1,r,rt<<1|1); } int getleft(int k,int l,int r){ key=k;x=l,y=r;sum=-inf; query_min(1,n,1); return sum; } void query_max(int l,int r,int rt){ if(l>y||r<x)return; if(l>=x&&r<=y){ int y=find(key,rt); if(v[y]<sum)sum=v[y]; return; } int mid=(l+r)>>1; query_max(l,mid,rt<<1); query_max(mid+1,r,rt<<1|1); } int getright(int k,int l,int r){ key=k;x=l,y=r;sum=inf; query_max(1,n,1); return sum; } int main(){ freopen(FILE".in","r",stdin); freopen(FILE".out","w",stdout); n=read(),m=read(); up(i,1,n)a[i]=read(); build(1,n,1); int ch,l,r,k,pos; while(m--){ ch=read(); if(ch!=3)l=read(),r=read(),k=read(); else pos=read(),k=read(); switch (ch){ case 1:printf("%d\n",getrank(k-1,l,r)+1);break; case 2:printf("%d\n",getK(k-1,l,r));break; case 3:change(pos,k);a[pos]=k;break; case 4:printf("%d\n",getleft(k,l,r));break; case 5:printf("%d\n",getright(k,l,r));break; } } return 0; }
treap套树状数组,可能由于后两个操作由logn变成log^2n的缘故,时间上的进步并不明显(bzoj上的srand不能用,但在校内oj上测跑得比原来快了1s);
#include<iostream> #include<cstdio> #include<cstdlib> #include<cstring> #include<string> #include<ctime> #include<cmath> #include<set> #include<map> #include<queue> #include<algorithm> #include<iomanip> #include<stack> using namespace std; #define FILE "dealing" #define up(i,j,n) for(int i=(j);i<=(n);i++) #define pii pair<int,int> #define LL int #define mem(f,g) memset(f,g,sizeof(f)) namespace IO{ char buf[1<<15],*fs,*ft; int gc(){return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?-1:*fs++;} int read(){ int ch=gc(),f=0,x=0; while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=gc();} while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=gc();} return f?-x:x; } int readint(){ int ch=getchar(),f=0,x=0; while(ch<'0'||ch>'9'){if(ch=='-')f=1;ch=getchar();} while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();} return f?-x:x; } }using namespace IO; const int maxn=2001000,inf=1000000000; int n,m,C[maxn],a[maxn]; int lowbit(int x){return x&-x;} namespace treap{ int c[maxn][2],v[maxn],siz[maxn],t[maxn],cnt=0,w[maxn]; void updata(int x){siz[x]=siz[c[x][0]]+siz[c[x][1]]+t[x];} void rotate(int &o,int d){int k=c[o][d];c[o][d]=c[k][d^1];c[k][d^1]=o;updata(o);updata(k);o=k;} void insert(int& o,int key){ if(!o){o=++cnt;siz[o]=t[o]=1;v[o]=key;w[o]=rand();return;} if(v[o]==key){t[o]++;updata(o);return;} int d=(key>v[o]); insert(c[o][d],key); updata(o); if(w[c[o][d]]>w[o])rotate(o,d); } void delet(int& o,int key){ if(v[o]==key){ if(t[o]>1){t[o]--;updata(o);return;} if(!c[o][0])o=c[o][1]; else if(!c[o][1])o=c[o][0]; else { int d=(w[c[o][1]]>w[c[o][0]]); rotate(o,d); delet(c[o][d^1],key); updata(o); } } else delet(c[o][key>v[o]],key),updata(o); } int getrank(int o,int key){//在o所在的treap内有多少点的值小于等于key int ans=0; while(o){ if(key>=v[o])ans+=siz[c[o][0]]+t[o],o=c[o][1]; else o=c[o][0]; } return ans; } }; namespace Bit{ int getrank(int key,int l,int r){//返回在[l,r]内有多少点值小于等于key int ans=0;l--; while(r)ans+=treap::getrank(C[r],key),r-=lowbit(r); while(l)ans-=treap::getrank(C[l],key),l-=lowbit(l); return ans; } int getK(int k,int l,int r){ int left=0,right=inf,mid; while(left+1<right){ mid=(left+right)>>1; if(getrank(mid,l,r)>k)right=mid; else left=mid; } if(getrank(left,l,r)<=k&&getrank(right,l,r)>=k)return right; return left; } void change(int pos,int key){ int i=pos; while(pos<=n){ treap::delet(C[pos],a[i]); treap::insert(C[pos],key); pos+=lowbit(pos); } } int getleft(int key,int l,int r){ int k=getrank(key-1,l,r); int left=0,right=key,mid; while(left+1<right){ mid=(left+right)>>1; if(getrank(mid,l,r)==k)right=mid; else left=mid; } if(getrank(right,l,r)==k&&getrank(left,l,r)<k)return right; else return left; } int getright(int key,int l,int r){ int k=getrank(key,l,r); int left=key,right=inf,mid; while(left+1<right){ mid=(left+right)>>1; if(getrank(mid,l,r)==k)left=mid; else right=mid; } if(getrank(right,l,r)>k&&getrank(left,l,r)==k)return right; else return left; } }; int main(){ int __size__ = 20 << 20; // 20MB char *__p__ = (char*)malloc(__size__) + __size__; __asm__("movl %0, %%esp\n" :: "r"(__p__)); n=read(),m=read(); srand((int)time(NULL)); up(i,1,n)a[i]=read(); up(i,1,n){ int k=i; while(k<=n)treap::insert(C[k],a[i]),k+=lowbit(k); } int ch,l,r,k,pos; int cnt=0; while(m--){ ch=read(); if(ch!=3)l=read(),r=read(),k=read(),cnt++; else pos=read(),k=read(); switch (ch){ case 1:printf("%d\n",Bit::getrank(k-1,l,r)+1);break; case 2:printf("%d\n",Bit::getK(k-1,l,r));break; case 3:Bit::change(pos,k);a[pos]=k;break; case 4:printf("%d\n",Bit::getleft(k,l,r));break; case 5:printf("%d\n",Bit::getright(k,l,r));break; } } return 0; }