学习笔记:主席树
What is Zhu Xi Shu?
主席树是可持久化数据结构的一种,他以线段树为原型,除了支持线段树常有的操作外,还支持对历史版本的查询功能。一般被用于可持久化数组/栈,或者就直接作为可持久化线段树使用。
How does it work?
对于要查询历史版本这一操作来说,最朴素的想法是每次修改后都重建一个新版本,这样的空间复杂度是\(O(n*m)\),\(m\)是操作次数,\(n\)是数据范围。
考虑优化。可以发现,对于每一次修改,只会有从根到被修改节点的路径上的\(logn\)个节点发生改变,其余不变。
所以考虑对不变的节点新树与原数共用,然后对发生修改的结点新建节点。空间复杂度\(O(n+m*logn)\)。
然后对于线段树原有的操作正常进行即可,只是要调用对应版本的入口,通过调用入口,我们可以访问任一历史版本。
How to code?(标程,指针实现)
#include<bits/stdc++.h>
using namespace std;
namespace STD
{
#define ll long long
#define rr register
const int SIZE=1e6+4;
int n,m,cnt;
int a[SIZE];
struct node{int val;node *l,*r;}root[SIZE];
void insert(node &before,node &now,int l,int r,int pos,int val)
{
if(l==r){now.val=val;return;}
int mid=(l+r)>>1;
if(pos<=mid)
{
now.r=before.r;
now.l=new node;
insert(*before.l,*now.l,l,mid,pos,val);
}
else
{
now.l=before.l;
now.r=new node;
insert(*before.r,*now.r,mid+1,r,pos,val);
}
}
int query(node &before,node &now,int l,int r,int pos)
{
if(l==r)
{
now.val=before.val;
return now.val;
}
int mid=(l+r)>>1;
now.l=before.l;
now.r=before.r;
if(pos<=mid)
return query(*before.l,*now.l,l,mid,pos);
else
return query(*before.r,*now.r,mid+1,r,pos);
}
void build(node &now,int l,int r)
{
if(l==r){now.val=a[l];return;}
int mid=(l+r)>>1;
now.l=new node;
now.r=new node;
build(*now.l,l,mid),build(*now.r,mid+1,r);
}
int read()
{
rr int x_read=0,y_read=1;
rr char c_read=getchar();
while(c_read<'0'||c_read>'9')
{
if(c_read=='-') y_read=-1;
c_read=getchar();
}
while(c_read<='9'&&c_read>='0')
{
x_read=(x_read*10)+(c_read^48);
c_read=getchar();
}
return x_read*y_read;
}
};
using namespace STD;
int main()
{
n=read(),m=read();
for(rr int i=1;i<=n;i++) a[i]=read();
build(root[0],1,n);
while(m--)
{
int id=read(),op=read(),pos=read();
if(op==1)
{
int val=read();
insert(root[id],root[++cnt],1,n,pos,val);
}
if(op==2)
{
int ans=query(root[id],root[++cnt],1,n,pos);
printf("%d\n",ans);
}
}
}
这里实际上是洛谷板子题的代码,传送门
Expansion(拓展应用)
可持久化数组。
其实就是上面的代码。。。。。。。(逃)
可持久化栈
还是上面的代码,只是要加一个数组记录每一个状态对应的栈顶位置。
#include<bits/stdc++.h>
using namespace std;
namespace STD
{
#define ll long long
#define rr register
const int SIZE=5e5+4;
int n;
int top[SIZE];
int fa[SIZE],to[SIZE];
int dire[SIZE],head[SIZE];
double val[SIZE],depth[SIZE];
double ans[SIZE];
inline void add(int f,int t)
{
static int num1=0;
to[++num1]=t;
dire[num1]=head[f];
head[f]=num1;
}
namespace Prisident_tree
{
struct node
{
int id;
node *l,*r;
}root[SIZE];
void build(node &now,int l,int r)
{
if(l==r) {if(l==1) now.id=1;return;}
rr int mid=(l+r)>>1;
now.l=new node;
now.r=new node;
build(*now.l,l,mid),build(*now.r,mid+1,r);
}
void insert(node &before ,node &now,int l,int r,int pos,int id)
{
if(l==r){now.id=id;return;}
rr int mid=(l+r)>>1;
if(pos<=mid)
{
now.r=before.r;
now.l=new node;
insert(*before.l,*now.l,l,mid,pos,id);
}
else
{
now.l=before.l;
now.r=new node;
insert(*before.r,*now.r,mid+1,r,pos,id);
}
}
int query(rr node now,rr int l,rr int r,rr int pos)
{
if(l==r) return now.id;
rr int mid=(l+r)>>1;
if(pos<=mid)
return query(*now.l,l,mid,pos);
else
return query(*now.r,mid+1,r,pos);
}
};
using namespace Prisident_tree;
int read()
{
rr int x_read=0,y_read=1;
rr char c_read=getchar();
while(c_read<'0'||c_read>'9')
{
if(c_read=='-') y_read=-1;
c_read=getchar();
}
while(c_read<='9'&&c_read>='0')
{
x_read=(x_read*10)+(c_read^48);
c_read=getchar();
}
return x_read*y_read;
}
inline double rate(rr int id1,rr int id2){return (val[id1]-val[id2])/(depth[id1]-depth[id2]);}
int find(int f,int now)
{
int l=1,r=top[f];
rr double a1,a2;
while(l<r)
{
rr int mid=(l+r)>>1;
rr int id1=query(root[f],1,n,mid);
rr int id2=query(root[f],1,n,mid+1);
a1=rate(id2,id1);
a2=rate(now,id2);
if(a1>=a2) r=mid;
else l=mid+1;
}
return l;
}
void dfs(int now)
{
if(depth[now]==1.00)
{
ans[now]=rate(now,1);
insert(root[fa[now]],root[now],1,n,2,now);
top[now]=2;
}
else
{
if(now!=1)
{
int pos=find(fa[now],now);
top[now]=pos+1;
int id=query(root[fa[now]],1,n,pos);
ans[now]=rate(now,id);
insert(root[fa[now]],root[now],1,n,top[now],now);
}
}
for(rr int i=head[now];i;i=dire[i])
{
depth[to[i]]=depth[now]+1.00;
dfs(to[i]);
}
}
};
using namespace STD;
int main()
{
n=read();
for(rr int i=1;i<=n;i++) int x=scanf("%lf",val+i);
for(rr int i=2;i<=n;i++) fa[i]=read(),add(fa[i],i);
build(root[1],1,n);
dfs(1);
for(rr int i=2;i<=n;i++) printf("%.10lf\n",-ans[i]);
}
这里实际上是一道名叫Lost My Music的题的AC代码,里面的主席树就是用来维护可持久化栈的,并且采用二分退栈。
题面自己搜吧,具体思路请看我上一篇博客。
主席树加减
2021.7.22
还是我肤浅了。。。
今天做题时用到了主席树查询区间最值&前趋&后继,用到了这玩意儿。。我不会,被大佬嘲讽了。。。。
主席树支持查询”历史“版本,其实版本可以以任何标准定义。
假如我们以数组下表为版本号,就可以实现区间操作了,也就是加减。
#include<bits/stdc++.h>
using namespace std;
namespace STD
{
#define ll long long
#define rr register
#define inf INT_MAX
const int N=100004;
int n,m;
int a[N],a_[N];
int read()
{
rr int x_read=0,y_read=1;
rr char c_read=getchar();
while(c_read<'0'||c_read>'9')
{
if(c_read=='-') y_read=-1;
c_read=getchar();
}
while(c_read<='9'&&c_read>='0')
{
x_read=(x_read*10)+(c_read^48);
c_read=getchar();
}
return x_read*y_read;
}
struct node
{
int cnt;
node *l,*r;
node(){cnt=0;}
};
class Pst
{
private:
node *root[N];
void Insert(node*,node*,int,int,int);
void Build(node*,int,int);
int Query(node*,node*,int,int,int);
public:
void build()
{
root[0]=new node();
Build(root[0],1,n);
}
void insert(int before,int now,int val)
{
if(root[now]==NULL) root[now]=new node();
Insert(root[before],root[now],1,n,val);
}
int query(int before,int now,int k){return Query(root[before],root[now],1,n,k);}
}t;
void Pst::Build(node *now,int l,int r)
{
if(l==r) return;
int mid=(l+r)>>1;
now->l=new node();
now->r=new node();
Build(now->l,l,mid),Build(now->r,mid+1,r);
}
void Pst::Insert(node *before,node *now,int l,int r,int val)
{
if(l==r){now->cnt++;return;}
int mid=(l+r)>>1;
if(val<=mid)
{
now->r=before->r;
now->l=new node();
now->l->cnt=before->l->cnt;
Insert(before->l,now->l,l,mid,val);
}
else
{
now->r=new node();
now->l=before->l;
now->r->cnt=before->r->cnt;
Insert(before->r,now->r,mid+1,r,val);
}
now->cnt=(now->l->cnt)+(now->r->cnt);
}
int Pst::Query(node *before,node *now,int l,int r,int rank)
{
if(l==r) return r;
int mid=(l+r)>>1;
int num=((now->l->cnt)-(before->l->cnt));
if(num<rank)
return Query(before->r,now->r,mid+1,r,rank-num);
else return Query(before->l,now->l,l,mid,rank);
}
};
using namespace STD;
int main()
{
n=read(),m=read();
for(rr int i=1;i<=n;i++)
a[i]=a_[i]=read();
sort(a_+1,a_+1+n);
int num=unique(a_+1,a_+1+n)-a_-1;
for(rr int i=1;i<=n;i++)
a[i]=lower_bound(a_+1,a_+1+n,a[i])-a_;
t.build();
for(rr int i=1;i<=n;i++)
t.insert(i-1,i,a[i]);
while(m--)
{
int x=read();
int y=read();
int k=read();
int ans=t.query(x-1,y,k);
printf("%d\n",a_[ans]);
}
}
这是本校\(OJ\)上的一道板子题,主要题意是查询给定区间的第\(k\)小数。
原题是北京大学 POJ 2104
#include<bits/stdc++.h>
using namespace std;
namespace STD
{
#define ll long long
#define rr register
#define inf INT_MAX
const int N=1e5+6;
int n,q,type,cnt;
int a[N],x[N],a_[N];
int to[N<<1],dire[N<<1],head[N];
inline void add(int f,int t)
{
static int num1=0;
to[++num1]=t;
dire[num1]=head[f];
head[f]=num1;
}
int read()
{
rr int x_read=0,y_read=1;
rr char c_read=getchar();
while(c_read<'0'||c_read>'9')
{
if(c_read=='-') y_read=-1;
c_read=getchar();
}
while(c_read<='9'&&c_read>='0')
{
x_read=(x_read*10)+(c_read^48);
c_read=getchar();
}
return x_read*y_read;
}
int fa[N],son[N],top[N],size[N],depth[N];
void dfs1(int x)
{
size[x]=1;
for(rr int i=head[x];i;i=dire[i])
{
if(to[i]==fa[x]) continue;
fa[to[i]]=x;
depth[to[i]]=depth[x]+1;
dfs1(to[i]);
size[x]+=size[to[i]];
if(size[to[i]]>size[son[x]]) son[x]=to[i];
}
}
void dfs2(int x)
{
if(x==son[fa[x]]) top[x]=top[fa[x]];
else top[x]=x;
for(rr int i=head[x];i;i=dire[i])
{
if(to[i]==fa[x]) continue;
dfs2(to[i]);
}
}
int LCA(int x,int y)
{
while(top[x]!=top[y])
{
if(depth[top[x]]>depth[top[y]])
x=fa[top[x]];
else y=fa[top[y]];
}
return depth[x]<depth[y]?x:y;
}
class Pst
{
private:
int tot;
int root[N];
int lc[(N<<1)+N*20];
int rc[(N<<1)+N*20];
int sum[(N<<1)+N*20];
void Insert(int before,int &now,int l,int r,int val);
void Build(int &now,int l,int r);
int query_sum(int before,int now,int l,int r,int st,int en);
int query_rank(int before,int now,int l,int r,int rank);
void Out(int now,int l,int r)
{
if(l==r){cout<<l<<' '<<sum[now]<<'\n';return;}
int mid=(l+r)>>1;
Out(lc[now],l,mid),Out(rc[now],mid+1,r);
}
public:
void build(){Build(root[0],1,n+2);}
void insert(int fa,int son,int val){Insert(root[fa],root[son],1,n+2,val);}
int prev(int fa,int son,int val)
{
int rank=query_sum(root[fa],root[son],1,n+2,1,val);
return query_rank(root[fa],root[son],1,n+2,rank);
}
int succ(int fa,int son,int val)
{
int rank=query_sum(root[fa],root[son],1,n+2,1,val);
return query_rank(root[fa],root[son],1,n+2,rank+1);
}
void out(int x){Out(root[x],1,n+2);}
}t;
void Pst::Build(int &now,int l,int r)
{
now=++tot;
if(l==r) return;
int mid=(l+r)>>1;
Build(lc[now],l,mid),Build(rc[now],mid+1,r);
}
void Pst::Insert(int before,int &now,int l,int r,int val)
{
now=++tot;
sum[now]=sum[before];
if(l==r){sum[now]++;return;}
int mid=(l+r)>>1;
if(val<=mid)
{
rc[now]=rc[before];
Insert(lc[before],lc[now],l,mid,val);
}
else
{
lc[now]=lc[before];
Insert(rc[before],rc[now],mid+1,r,val);
}
sum[now]=sum[lc[now]]+sum[rc[now]];
}
int Pst::query_sum(int before,int now,int l,int r,int st,int en)
{
if(st<=l&&r<=en) return sum[now]-sum[before];
int mid=(l+r)>>1;
int ret=0;
if(st<=mid) ret+=query_sum(lc[before],lc[now],l,mid,st,en);
if(mid<en) ret+=query_sum(rc[before],rc[now],mid+1,r,st,en);
return ret;
}
int Pst::query_rank(int before,int now,int l,int r,int rank)
{
if(l==r)
{
if(sum[now]-sum[before])
return l;
return inf;
}
int mid=(l+r)>>1;
int num=(sum[lc[now]]-sum[lc[before]]);
if(num>=rank)
return query_rank(lc[before],lc[now],l,mid,rank);
return query_rank(rc[before],rc[now],mid+1,r,rank-num);
}
void dfs3(int x)
{
t.insert(fa[x],x,a[x]);
for(rr int i=head[x];i;i=dire[i])
{
if(to[i]==fa[x]) continue;
dfs3(to[i]);
}
}
int find(int x)
{
int l=1,r=cnt;
while(l<r)
{
int mid=(l+r+1)>>1;
if(a_[mid]<=x) l=mid;
else r=mid-1;
}
return l;
}
};
using namespace STD;
int main()
{
n=read(),q=read(),type=read();
for(rr int i=1;i<=n;i++) a[i]=a_[i]=read();
for(rr int i=1;i<n;i++)
{
int u=read(),v=read();
add(u,v),add(v,u);
}
sort(a_+1,a_+1+n);
cnt=unique(a_+1,a_+1+n)-a_-1;
for(rr int i=1;i<=n;i++)
a[i]=lower_bound(a_+1,a_+1+cnt,a[i])-a_;
t.build();
dfs1(1);
dfs2(1);
dfs3(1);
//for(rr int i=1;i<=cnt;i++) cout<<a_[i]<<'\n';
//cout<<'\n';
int lastans=0;
while(q--)
{
int r=read(),k=read();
for(rr int i=1;i<=k;i++)
{
x[i]=read();
x[i]=(x[i]-1+lastans*type)%n+1;
}
int lca=x[1];
int r_=find(r);
//cout<<"R: "<<r<<' '<<a_[r_]<<'\n';
//cout<<"R_: "<<r_<<'\n';
for(rr int i=2;i<=k;i++)
lca=LCA(lca,x[i]);
int ans=inf;
for(rr int i=1;i<=k;i++)
{
int prev=t.prev(fa[lca],x[i],r_);
int succ=t.succ(fa[lca],x[i],r_);
//cout<<"PREV: "<<prev<<'\n';
//cout<<"SUCC: "<<succ<<'\n';
if(prev!=inf)
ans=min(ans,abs(a_[prev]-r));
if(succ!=inf)
ans=min(ans,abs(a_[succ]-r));
}
lastans=ans;
printf("%d\n",ans);
for(rr int i=1;i<=k;i++)
x[i]=0;
}
}
这是他在树上的应用,来源于我在今天更新的那篇模拟题的T2
2021.7.16 现役