可持久化线段树(主席树)
定义
可持久化线段树是可以保留历史版本的线段树,相当于保留了每次修改后的线段树,并且可以对每次修改后的结果进行查询
主要思想
对于询问历史版本的问题,我们对每次修改都新建一棵线段树
但是如果修改次数特别多,这样肯定是不可行的,这就要用到可持久化线段树了
我们通过观察可以发现,对于线段树的每次修改,不论是单点还是区间,所改变的节点都是log级别的
于是我们所建的线段树只新建那log级别改变的节点,不变的节点直接指向历史版本就行了
具体实现
-
建树
和普通线段树基本相同
int build(int L,int R){//建树 int k=size++; if(L==R){sum[k]=a[L-1];return k;} int mid=(L+R)>>1; l[k]=build(L,mid),r[k]=build(mid+1,R); sum[k]=sum[l[k]]+sum[r[k]]; return k; }
-
区间修改
对于所有的修改,都将原节点复制一遍,然后在新建的节点上修改
对于没有修改的子节点就直接指向上个版本的节点即可int modify(int history,int L,int R,int w,int l1,int r1){//区间修改,history是上个版本的当前节点的编号,L,R为当前节点范围,l1,l2为要修改的范围 int k=size++; lazy[k]=lazy[history],sum[k]=sum[history]+1ll*(r1-l1+1)*w,l[k]=l[history],r[k]=r[history];//克隆节点 if(l1<=L&&r1>=R){//整个区间都要被修改,打个标记返回 lazy[k]+=w;return k; } int mid=(L+R)>>1; if(l1<=mid)l[k]=modify(l[history],L,mid,w,l1,min(mid,r1)); if(r1>mid)r[k]=modify(r[history],mid+1,R,w,max(mid+1,l1),r1); return k; } int change(int history,int L,int R,int w,int x){//单点修改 int k=size++; lazy[k]=lazy[history],sum[k]=sum[history]+w,l[k]=l[history],r[k]=r[history];//克隆节点 if(L==R)return k;//如果是叶子节点就直接返回 int mid=(L+R)>>1; if(x<=mid)l[k]=change(l[history],L,mid,w,x); else r[k]=change(r[history],mid+1,R,w,x); return k; }
-
区间查询
和普通线段树区间查询一样,只是lazy标记不能下放(因为子节点是共用的),而是统计lazy标记造成的影响
int query(int x,int L,int R,int l1,int r1){//询问 if(l1<=L&&r1>=R)return sum[x]; int mid=(L+R)>>1; int ans=lazy[x]*(r1-l1+1);//标记不能下放,因为子节点也是共用的,只能统计标记造成的影响 if(l1<=mid)ans+=query(l[x],L,mid,l1,min(mid,r1)); if(r1>mid)ans+=query(r[x],mid+1,R,max(mid+1,l1),r1); return ans; }
模板
-
非动态开点
#include<cstdio> #include<algorithm> using namespace std; #define maxn 100005 int n,m,a[maxn],root[maxn],size,lazy[maxn*70],l[maxn*70],r[maxn*70],sum[maxn*70];//非动态开点则节点要开到n*logn*4 int build(int L,int R){//建树 int k=size++; if(L==R){sum[k]=a[L-1];return k;} int mid=(L+R)>>1; l[k]=build(L,mid),r[k]=build(mid+1,R); sum[k]=sum[l[k]]+sum[r[k]]; return k; } int modify(int history,int L,int R,int w,int l1,int r1){//区间修改,history是上个版本的当前节点的编号,L,R为当前节点范围,l1,l2为要修改的范围 int k=size++; lazy[k]=lazy[history],sum[k]=sum[history]+1ll*(r1-l1+1)*w,l[k]=l[history],r[k]=r[history];//克隆节点 if(l1<=L&&r1>=R){//整个区间都要被修改,打个标记返回 lazy[k]+=w;return k; } int mid=(L+R)>>1; if(l1<=mid)l[k]=modify(l[history],L,mid,w,l1,min(mid,r1)); if(r1>mid)r[k]=modify(r[history],mid+1,R,w,max(mid+1,l1),r1); return k; } int change(int history,int L,int R,int w,int x){//单点修改 int k=size++; lazy[k]=lazy[history],sum[k]=sum[history]+w,l[k]=l[history],r[k]=r[history];//克隆节点 if(L==R)return k;//如果是叶子节点就直接返回 int mid=(L+R)>>1; if(x<=mid)l[k]=change(l[history],L,mid,w,x); else r[k]=change(r[history],mid+1,R,w,x); return k; } int query(int x,int L,int R,int l1,int r1){//询问 if(l1<=L&&r1>=R)return sum[x]; int mid=(L+R)>>1; int ans=lazy[x]*(r1-l1+1);//标记不能下放,因为子节点也是共用的,只能统计标记造成的影响 if(l1<=mid)ans+=query(l[x],L,mid,l1,min(mid,r1)); if(r1>mid)ans+=query(r[x],mid+1,R,max(mid+1,l1),r1); return ans; } int main(){ return 0; }
-
动态开点
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; #define maxn 100005 struct Node{ Node *l,*r; bool l1,r1;//记录左右子树是不是当前节点创建的,如果不delete就不需要 int lazy,sum; Node(){ memset(this,0,sizeof(Node)); } }root[maxn]; int n,m,a[maxn]; void build(Node &x,int L,int R){//建树 if(L==R){x.sum=a[L-1];return;} int mid=(L+R)>>1; build(*(x.l1=1,x.l=new Node),L,mid),build(*(x.r1=1,x.r=new Node),mid+1,R); x.sum=x.l->sum+x.r->sum; } void modify(Node &history,Node &x,int L,int R,int w,int l1,int r1){//区间修改 x=history,x.sum+=(r1-l1+1)*w,x.r1=x.l1=0; if(l1<=L&&r1>=R){ x.lazy+=w;return; } int mid=(L+R)>>1; if(l1<=mid)modify(*history.l,*(x.l1=1,x.l=new Node),L,mid,w,l1,min(mid,r1)); if(r1>mid)modify(*history.r,*(x.r1=1,x.r=new Node),mid+1,R,w,max(mid+1,l1),r1); } void change(Node &history,Node &x,int L,int R,int w,int l1){//单点修改 x=history,x.sum+=w,x.l1=x.r1=0;//克隆节点 if(L==R)return;//如果是叶子节点就直接返回 int mid=(L+R)>>1; if(l1<=mid)change(*history.l,*(x.l1=1,x.l=new Node),L,mid,w,l1); else change(*history.r,*(x.r1=1,x.r=new Node),mid+1,R,w,l1); } int query(Node &x,int L,int R,int l1,int r1){//区间查询 if(l1<=L&&r1>=R)return x.sum; int mid=(L+R)>>1,ans=x.lazy*(r1-l1+1); if(l1<=mid)ans+=query(*x.l,L,mid,l1,min(mid,r1)); if(r1>mid)ans+=query(*x.r,mid+1,R,max(mid+1,l1),r1); return ans; } void remove_node(Node *x){//节点空间释放 if(x->l1)remove_node(x->l); if(x->r1)remove_node(x->r); delete x; } void remove(int n){//线段树空间释放 for(int i=0;i<=n;i++){ if(root[i].l1)remove_node(root[i].l); if(root[i].r1)remove_node(root[i].r); } memset(root,0,sizeof(root)); } int main(){ return 0; }
例题hdu4348.To the moon
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; #define maxn 100005 #define LL long long struct Node{ Node *l,*r; bool l1,r1;//记录左右子树是不是当前节点创建的,如果不delete就不需要 int lazy; LL sum; Node(){ memset(this,0,sizeof(Node)); } }root[maxn]; int n,m,a[maxn]; void build(Node &x,int L,int R){//建树 if(L==R){x.sum=a[L-1];return;} int mid=(L+R)>>1; build(*(x.l1=1,x.l=new Node),L,mid),build(*(x.r1=1,x.r=new Node),mid+1,R); x.sum=x.l->sum+x.r->sum; } void modify(Node &history,Node &x,int L,int R,int w,int l1,int r1){//区间修改 x=history,x.sum+=1ll*(r1-l1+1)*w,x.r1=x.l1=0; if(l1<=L&&r1>=R){ x.lazy+=w;return; } int mid=(L+R)>>1; if(l1<=mid)modify(*history.l,*(x.l1=1,x.l=new Node),L,mid,w,l1,min(mid,r1)); if(r1>mid)modify(*history.r,*(x.r1=1,x.r=new Node),mid+1,R,w,max(mid+1,l1),r1); } LL query(Node &x,int L,int R,int l1,int r1){//区间查询 if(l1<=L&&r1>=R)return x.sum; int mid=(L+R)>>1;LL ans=x.lazy*(r1-l1+1); if(l1<=mid)ans+=query(*x.l,L,mid,l1,min(mid,r1)); if(r1>mid)ans+=query(*x.r,mid+1,R,max(mid+1,l1),r1); return ans; } void remove_node(Node *x){//节点空间释放 if(x->l1)remove_node(x->l); if(x->r1)remove_node(x->r); delete x; } void remove(int n){//线段树空间释放 for(int i=0;i<=n;i++){ if(root[i].l1)remove_node(root[i].l); if(root[i].r1)remove_node(root[i].r); } memset(root,0,sizeof(root)); } void work(){ for(int i=0;i<n;i++)scanf("%d",a+i); build(root[0],1,n); int L,R,d,time=0;char s[2]; for(int i=0;i<m;i++){ scanf("%s",s); if(s[0]=='C'){ scanf("%d%d%d",&L,&R,&d); time++; modify(root[time-1],root[time],1,n,d,L,R); } else if(s[0]=='Q'){ scanf("%d%d",&L,&R); printf("%lld\n",query(root[time],1,n,L,R)); } else if(s[0]=='H'){ scanf("%d%d%d",&L,&R,&d); printf("%lld\n",query(root[d],1,n,L,R)); } else{ scanf("%d",&d); for(;time>d;time--){ if(root[time].l1)remove_node(root[time].l); if(root[time].r1)remove_node(root[time].r); } } } remove(time); } int main(){ while(~scanf("%d%d",&n,&m)){ work(); } return 0; }
求区间第k大
可持久化线段树有一个非常重要的用法——求区间第k大
具体做法
将原序列排序,开一个1-n的线段树,每个节点刚开始都为0
然后按照原序列的顺序,每次找到原序列一个数在排好序的序列的位置,在线段树相应位置+1
这时,第i个线段树维护的就是[1,i]在排好序的序列中的位置,我们可以快速求出[1,i]的第k大
即如果左子节点的值x>=k,那么这个数就是左子节点的第k大,否则是右子节点的第k-x大
如果求[l,r]第k大,只需要对第r棵线段树和第l-1棵作差,然后按照上面的方法即可
int query_k(Node &x,Node &y,int L,int R,int k){//询问x线段树-y线段树中的第k大 if(L==R)return L; int num=x.l->sum-y.l->sum,mid=(L+R)>>1; if(num>=k)return query_k(*x.l,*y.l,L,mid,k); else return query_k(*x.r,*y.r,mid+1,R,k-num); }
例题luoguP3834 【模板】可持久化线段树 1(主席树)
#include<cstdio> #include<algorithm> #include<cstring> using namespace std; #define maxn 200005 struct Node{ Node *l,*r; bool l1,r1;//记录左右子树是不是当前节点创建的,如果不delete就不需要 int lazy,sum; Node(){ memset(this,0,sizeof(Node)); } }root[maxn]; int n,m,a[maxn],b[maxn],st; void build(Node &x,int L,int R){//建树 if(L==R){x.sum=a[L-1];return;} int mid=(L+R)>>1; build(*(x.l1=1,x.l=new Node),L,mid),build(*(x.r1=1,x.r=new Node),mid+1,R); x.sum=x.l->sum+x.r->sum; } void modify(Node &history,Node &x,int L,int R,int w,int l1,int r1){//区间修改 x=history,x.sum+=(r1-l1+1)*w,x.r1=x.l1=0; if(l1<=L&&r1>=R){ x.lazy+=w;return; } int mid=(L+R)>>1; if(l1<=mid)modify(*history.l,*(x.l1=1,x.l=new Node),L,mid,w,l1,min(mid,r1)); if(r1>mid)modify(*history.r,*(x.r1=1,x.r=new Node),mid+1,R,w,max(mid+1,l1),r1); } void change(Node &history,Node &x,int L,int R,int w,int l1){//单点修改 x=history,x.sum+=w,x.l1=x.r1=0;//克隆节点 if(L==R)return;//如果是叶子节点就直接返回 int mid=(L+R)>>1; if(l1<=mid)change(*history.l,*(x.l1=1,x.l=new Node),L,mid,w,l1); else change(*history.r,*(x.r1=1,x.r=new Node),mid+1,R,w,l1); } int query(Node &x,int L,int R,int l1,int r1){//区间查询 if(l1<=L&&r1>=R)return x.sum; int mid=(L+R)>>1,ans=x.lazy*(r1-l1+1); if(l1<=mid)ans+=query(*x.l,L,mid,l1,min(mid,r1)); if(r1>mid)ans+=query(*x.r,mid+1,R,max(mid+1,l1),r1); return ans; } int query_k(Node &x,Node &y,int L,int R,int k){//询问x-y线段树中的第k大 if(L==R)return L; int num=x.l->sum-y.l->sum,mid=(L+R)>>1; if(num>=k)return query_k(*x.l,*y.l,L,mid,k); else return query_k(*x.r,*y.r,mid+1,R,k-num); } void remove_node(Node *x){//节点空间释放 if(x->l1)remove_node(x->l); if(x->r1)remove_node(x->r); delete x; } void remove(int n){//线段树空间释放 for(int i=0;i<=n;i++){ if(root[i].l1)remove_node(root[i].l); if(root[i].r1)remove_node(root[i].r); } memset(root,0,sizeof(root)); } int main(){ int n,m;scanf("%d%d",&n,&m); for(int i=0;i<n;i++)scanf("%d",a+i),b[i]=a[i]; sort(b,b+n),st=unique(b,b+n)-b; build(root[0],1,st); for(int i=0;i<n;i++){ int x=lower_bound(b,b+st,a[i])-b; change(root[i],root[i+1],1,st,1,x+1); } int L,R,k; for(int i=0;i<m;i++){ scanf("%d%d%d",&L,&R,&k); printf("%d\n",b[query_k(root[R],root[L-1],1,st,k)-1]); } remove(n); return 0; }