学习笔记:主席树

What is Zhu Xi Shu?

  主席树是可持久化数据结构的一种,他以线段树为原型,除了支持线段树常有的操作外,还支持对历史版本的查询功能。一般被用于可持久化数组/栈,或者就直接作为可持久化线段树使用。

How does it work?

  对于要查询历史版本这一操作来说,最朴素的想法是每次修改后都重建一个新版本,这样的空间复杂度是\(O(n*m)\)\(m\)是操作次数,\(n\)是数据范围。
  考虑优化。可以发现,对于每一次修改,只会有从根到被修改节点的路径上的\(logn\)个节点发生改变,其余不变。
  所以考虑对不变的节点新树与原数共用,然后对发生修改的结点新建节点。空间复杂度\(O(n+m*logn)\)
  然后对于线段树原有的操作正常进行即可,只是要调用对应版本的入口,通过调用入口,我们可以访问任一历史版本。

How to code?(标程,指针实现)

#include<bits/stdc++.h>
using namespace std;
namespace STD
{
    #define ll long long
    #define rr register 
    const int SIZE=1e6+4;
    int n,m,cnt;
    int a[SIZE];
    struct node{int val;node *l,*r;}root[SIZE];
    void insert(node &before,node &now,int l,int r,int pos,int val)
    {   
        if(l==r){now.val=val;return;}
        int mid=(l+r)>>1;
        if(pos<=mid)
        {
            now.r=before.r;
            now.l=new node;
            insert(*before.l,*now.l,l,mid,pos,val);
        }
        else
        {
            now.l=before.l;
            now.r=new node;
            insert(*before.r,*now.r,mid+1,r,pos,val);
        }
    }
    int query(node &before,node &now,int l,int r,int pos)
    {
        if(l==r)
        {
            now.val=before.val;
            return now.val;
        }
        int mid=(l+r)>>1;
        now.l=before.l;
        now.r=before.r;        
        if(pos<=mid)
            return query(*before.l,*now.l,l,mid,pos);
        else
            return query(*before.r,*now.r,mid+1,r,pos);
    }
    void build(node &now,int l,int r)
    {
        if(l==r){now.val=a[l];return;}
        int mid=(l+r)>>1;
        now.l=new node;
        now.r=new node;
        build(*now.l,l,mid),build(*now.r,mid+1,r);
    }
    int read()
    {
        rr int x_read=0,y_read=1;
        rr char c_read=getchar();
        while(c_read<'0'||c_read>'9')
        {
            if(c_read=='-') y_read=-1;
            c_read=getchar();
        }
        while(c_read<='9'&&c_read>='0')
        {
            x_read=(x_read*10)+(c_read^48);
            c_read=getchar();
        }
        return x_read*y_read;
    }
};
using namespace STD;
int main()
{
    n=read(),m=read();
    for(rr int i=1;i<=n;i++) a[i]=read();
    build(root[0],1,n);
    while(m--)
    {
        int id=read(),op=read(),pos=read();
        if(op==1)
        {
            int val=read();
            insert(root[id],root[++cnt],1,n,pos,val);
        }
        if(op==2)
        {
            int ans=query(root[id],root[++cnt],1,n,pos);
            printf("%d\n",ans);
        }
    }
}

  这里实际上是洛谷板子题的代码,传送门

Expansion(拓展应用)

可持久化数组。

  其实就是上面的代码。。。。。。。(逃)

可持久化栈

  还是上面的代码,只是要加一个数组记录每一个状态对应的栈顶位置。

#include<bits/stdc++.h>
using namespace std;
namespace STD
{
    #define ll long long
    #define rr register 
    const int SIZE=5e5+4;
    int n;
    int top[SIZE];
    int fa[SIZE],to[SIZE];
    int dire[SIZE],head[SIZE];
    double val[SIZE],depth[SIZE];
    double ans[SIZE];
    inline void add(int f,int t)
    {
        static int num1=0;
        to[++num1]=t;
        dire[num1]=head[f];
        head[f]=num1;
    }
    namespace Prisident_tree
    {
        struct node
        {
            int id;
            node *l,*r;
        }root[SIZE];
        void build(node &now,int l,int r)
        {
            if(l==r) {if(l==1) now.id=1;return;}
            rr int mid=(l+r)>>1;
            now.l=new node;
            now.r=new node;
            build(*now.l,l,mid),build(*now.r,mid+1,r);
        }
        void insert(node &before ,node &now,int l,int r,int pos,int id)
        {
            if(l==r){now.id=id;return;}
            rr int mid=(l+r)>>1;
            if(pos<=mid)
            {
                now.r=before.r;
                now.l=new node;
                insert(*before.l,*now.l,l,mid,pos,id);
            }
            else 
            {
                now.l=before.l;
                now.r=new node;
                insert(*before.r,*now.r,mid+1,r,pos,id);
            }
        }
        int query(rr node now,rr int l,rr int r,rr int pos)
        {
            if(l==r) return now.id;
            rr int mid=(l+r)>>1;
            if(pos<=mid) 
                return query(*now.l,l,mid,pos);
            else 
                return query(*now.r,mid+1,r,pos);
        }  
    };
    using namespace Prisident_tree;
    int read()
    {
        rr int x_read=0,y_read=1;
        rr char c_read=getchar();
        while(c_read<'0'||c_read>'9')
        {
            if(c_read=='-') y_read=-1;
            c_read=getchar();
        }   
        while(c_read<='9'&&c_read>='0')
        {
            x_read=(x_read*10)+(c_read^48);
            c_read=getchar();
        }
        return x_read*y_read;
    }
    inline double rate(rr int id1,rr int id2){return (val[id1]-val[id2])/(depth[id1]-depth[id2]);}
    int find(int f,int now)
    {
        int l=1,r=top[f];
        rr double a1,a2;
        while(l<r)
        {
            rr int mid=(l+r)>>1;
            rr int id1=query(root[f],1,n,mid);
            rr int id2=query(root[f],1,n,mid+1);
            a1=rate(id2,id1);
            a2=rate(now,id2);
            if(a1>=a2) r=mid;
            else l=mid+1;
        }
        return l;
    }
    void dfs(int now)
    {
        if(depth[now]==1.00)
        {
            ans[now]=rate(now,1);
            insert(root[fa[now]],root[now],1,n,2,now);
            top[now]=2;
        }
        else
        {
            if(now!=1)
            {
                int pos=find(fa[now],now);
                top[now]=pos+1;
                int id=query(root[fa[now]],1,n,pos);
                ans[now]=rate(now,id);
                insert(root[fa[now]],root[now],1,n,top[now],now);
            }
        }
        for(rr int i=head[now];i;i=dire[i])
        {
            depth[to[i]]=depth[now]+1.00;
            dfs(to[i]);
        }
    }
};
using namespace STD;
int main()
{
    n=read();
    for(rr int i=1;i<=n;i++) int x=scanf("%lf",val+i);
    for(rr int i=2;i<=n;i++) fa[i]=read(),add(fa[i],i);
    build(root[1],1,n);
    dfs(1);
    for(rr int i=2;i<=n;i++) printf("%.10lf\n",-ans[i]);
}

  这里实际上是一道名叫Lost My Music的题的AC代码,里面的主席树就是用来维护可持久化栈的,并且采用二分退栈。
  题面自己搜吧,具体思路请看我上一篇博客。

主席树加减

2021.7.22
  还是我肤浅了。。。
  今天做题时用到了主席树查询区间最值&前趋&后继,用到了这玩意儿。。我不会,被大佬嘲讽了。。。。
  主席树支持查询”历史“版本,其实版本可以以任何标准定义。
  假如我们以数组下表为版本号,就可以实现区间操作了,也就是加减。

#include<bits/stdc++.h>
using namespace std;
namespace STD
{
	#define ll long long
	#define rr register 
	#define inf INT_MAX
	const int N=100004;
	int n,m;
	int a[N],a_[N];
	int read()
	{
		rr int x_read=0,y_read=1;
		rr char c_read=getchar();
		while(c_read<'0'||c_read>'9')
		{
			if(c_read=='-') y_read=-1;
			c_read=getchar();
		}
		while(c_read<='9'&&c_read>='0')
		{
			x_read=(x_read*10)+(c_read^48);
			c_read=getchar();
		}
		return x_read*y_read;
	}
	struct node
	{
		int cnt;
		node *l,*r;
		node(){cnt=0;}
	};
	class Pst
	{
		private:
			node *root[N];
			void Insert(node*,node*,int,int,int);
			void Build(node*,int,int);
			int Query(node*,node*,int,int,int);
		public:
			void build()
			{
				root[0]=new node();
				Build(root[0],1,n);
			}
			void insert(int before,int now,int val)
			{
				if(root[now]==NULL) root[now]=new node();
				Insert(root[before],root[now],1,n,val);
			}
			int query(int before,int now,int k){return Query(root[before],root[now],1,n,k);}
	}t;
	void Pst::Build(node *now,int l,int r)
	{
		if(l==r) return;
		int mid=(l+r)>>1;
		now->l=new node();
		now->r=new node();
		Build(now->l,l,mid),Build(now->r,mid+1,r);
	}
	void Pst::Insert(node *before,node *now,int l,int r,int val)
	{	
		if(l==r){now->cnt++;return;}
		int mid=(l+r)>>1;
		if(val<=mid)   
		{
			now->r=before->r;
			now->l=new node();
			now->l->cnt=before->l->cnt;
			Insert(before->l,now->l,l,mid,val);
		}
		else
		{
			now->r=new node();
			now->l=before->l;
			now->r->cnt=before->r->cnt;
			Insert(before->r,now->r,mid+1,r,val);
		}
		now->cnt=(now->l->cnt)+(now->r->cnt);
	}
	int Pst::Query(node *before,node *now,int l,int r,int rank)
	{
		if(l==r) return r;
		int mid=(l+r)>>1;
		int num=((now->l->cnt)-(before->l->cnt));
		if(num<rank)
			return Query(before->r,now->r,mid+1,r,rank-num);
		else return Query(before->l,now->l,l,mid,rank);
	}
};
using namespace STD;
int main()
{
	n=read(),m=read();
	for(rr int i=1;i<=n;i++)
		a[i]=a_[i]=read();
	sort(a_+1,a_+1+n);
	int num=unique(a_+1,a_+1+n)-a_-1;
	for(rr int i=1;i<=n;i++)
		a[i]=lower_bound(a_+1,a_+1+n,a[i])-a_;
	t.build();
	for(rr int i=1;i<=n;i++)
		t.insert(i-1,i,a[i]);
	while(m--)
	{
		int x=read();
		int y=read();
		int k=read();
		int ans=t.query(x-1,y,k);
		printf("%d\n",a_[ans]);
	}
}

  这是本校\(OJ\)上的一道板子题,主要题意是查询给定区间的第\(k\)小数。
  原题是北京大学 POJ 2104

#include<bits/stdc++.h>
using namespace std;
namespace STD
{
	#define ll long long
	#define rr register
	#define inf INT_MAX
	const int N=1e5+6;
	int n,q,type,cnt;
	int a[N],x[N],a_[N];
	int to[N<<1],dire[N<<1],head[N];
	inline void add(int f,int t)
	{
		static int num1=0;
		to[++num1]=t;
		dire[num1]=head[f];
		head[f]=num1;
	}
	int read()
	{
		rr int x_read=0,y_read=1;
		rr char c_read=getchar();
		while(c_read<'0'||c_read>'9')
		{
			if(c_read=='-') y_read=-1;
			c_read=getchar();
		}
		while(c_read<='9'&&c_read>='0')
		{
			x_read=(x_read*10)+(c_read^48);
			c_read=getchar();
		}
		return x_read*y_read;
	}
	int fa[N],son[N],top[N],size[N],depth[N];
	void dfs1(int x)
	{
		size[x]=1;
		for(rr int i=head[x];i;i=dire[i])
		{
			if(to[i]==fa[x]) continue;
			fa[to[i]]=x;
			depth[to[i]]=depth[x]+1;
			dfs1(to[i]);
			size[x]+=size[to[i]];
			if(size[to[i]]>size[son[x]]) son[x]=to[i];
		}
	
	}
	void dfs2(int x)
	{
		if(x==son[fa[x]]) top[x]=top[fa[x]];
		else top[x]=x;
		for(rr int i=head[x];i;i=dire[i])
		{
			if(to[i]==fa[x]) continue;
			dfs2(to[i]);
		}
	}
	int LCA(int x,int y)
	{
		while(top[x]!=top[y])
		{
			if(depth[top[x]]>depth[top[y]])
				x=fa[top[x]];
			else y=fa[top[y]];
		}
		return depth[x]<depth[y]?x:y;
	}
	class Pst
	{
		private:
			int tot;
			int root[N];
			int lc[(N<<1)+N*20];
			int rc[(N<<1)+N*20];
			int sum[(N<<1)+N*20];
			void Insert(int before,int &now,int l,int r,int val);
			void Build(int &now,int l,int r);
			int query_sum(int before,int now,int l,int r,int st,int en);
			int query_rank(int before,int now,int l,int r,int rank);
			void Out(int now,int l,int r)
			{
				if(l==r){cout<<l<<' '<<sum[now]<<'\n';return;}
				int mid=(l+r)>>1;
				Out(lc[now],l,mid),Out(rc[now],mid+1,r);
			}
		public:
			void build(){Build(root[0],1,n+2);}
			void insert(int fa,int son,int val){Insert(root[fa],root[son],1,n+2,val);}
			int prev(int fa,int son,int val)
			{
				int rank=query_sum(root[fa],root[son],1,n+2,1,val);
				return query_rank(root[fa],root[son],1,n+2,rank);
			}
			int succ(int fa,int son,int val)
			{
				int rank=query_sum(root[fa],root[son],1,n+2,1,val);				
				return query_rank(root[fa],root[son],1,n+2,rank+1);
			}
			void out(int x){Out(root[x],1,n+2);}
	}t;
	void Pst::Build(int &now,int l,int r)
	{
		now=++tot;
		if(l==r) return;
		int mid=(l+r)>>1;
		Build(lc[now],l,mid),Build(rc[now],mid+1,r);
	}
	void Pst::Insert(int before,int &now,int l,int r,int val)
	{
		now=++tot;
		sum[now]=sum[before];
		if(l==r){sum[now]++;return;}
		int mid=(l+r)>>1;
		if(val<=mid)
		{
			rc[now]=rc[before];
			Insert(lc[before],lc[now],l,mid,val);
		}
		else
		{
			lc[now]=lc[before];
			Insert(rc[before],rc[now],mid+1,r,val);
		}
		sum[now]=sum[lc[now]]+sum[rc[now]];
	}
	int Pst::query_sum(int before,int now,int l,int r,int st,int en)
	{
		if(st<=l&&r<=en) return sum[now]-sum[before];
		int mid=(l+r)>>1;
		int ret=0;
		if(st<=mid) ret+=query_sum(lc[before],lc[now],l,mid,st,en);
		if(mid<en) ret+=query_sum(rc[before],rc[now],mid+1,r,st,en);
		return ret;
	}
	int Pst::query_rank(int before,int now,int l,int r,int rank)
	{
		if(l==r) 
		{
			if(sum[now]-sum[before])
				return l;
			return inf;
		}
		int mid=(l+r)>>1;
		int num=(sum[lc[now]]-sum[lc[before]]);
		if(num>=rank)
			return query_rank(lc[before],lc[now],l,mid,rank);
		return query_rank(rc[before],rc[now],mid+1,r,rank-num);
	}
	void dfs3(int x)
	{
		t.insert(fa[x],x,a[x]);
		for(rr int i=head[x];i;i=dire[i])
		{
			if(to[i]==fa[x]) continue;
			dfs3(to[i]);
		}
	}
	int find(int x)
	{
		int l=1,r=cnt;
		while(l<r)
		{
			int mid=(l+r+1)>>1;
			if(a_[mid]<=x) l=mid;
			else r=mid-1;
		}
		return l;
	}
};
using namespace STD;
int main()
{
	n=read(),q=read(),type=read();
	for(rr int i=1;i<=n;i++) a[i]=a_[i]=read();
	for(rr int i=1;i<n;i++)
	{
		int u=read(),v=read();
		add(u,v),add(v,u);
	}
	sort(a_+1,a_+1+n);
	cnt=unique(a_+1,a_+1+n)-a_-1;
	for(rr int i=1;i<=n;i++)
		a[i]=lower_bound(a_+1,a_+1+cnt,a[i])-a_;
	t.build();
	dfs1(1);
	dfs2(1);
	dfs3(1);
	//for(rr int i=1;i<=cnt;i++) cout<<a_[i]<<'\n';
	//cout<<'\n';
	int lastans=0;
	while(q--)
	{	
		int r=read(),k=read();
		for(rr int i=1;i<=k;i++)
		{
			x[i]=read();
			x[i]=(x[i]-1+lastans*type)%n+1;
		}
		int lca=x[1];
		int r_=find(r);
		//cout<<"R: "<<r<<' '<<a_[r_]<<'\n';
		//cout<<"R_: "<<r_<<'\n';
		for(rr int i=2;i<=k;i++)
			lca=LCA(lca,x[i]);
		int ans=inf;
		for(rr int i=1;i<=k;i++)
		{
			int prev=t.prev(fa[lca],x[i],r_);
			int succ=t.succ(fa[lca],x[i],r_);
			//cout<<"PREV: "<<prev<<'\n';
			//cout<<"SUCC: "<<succ<<'\n';
			if(prev!=inf)
				ans=min(ans,abs(a_[prev]-r));
			if(succ!=inf)
				ans=min(ans,abs(a_[succ]-r));
		}
		lastans=ans;
		printf("%d\n",ans);
		for(rr int i=1;i<=k;i++)
			x[i]=0;
	}
}

  这是他在树上的应用,来源于我在今天更新的那篇模拟题的T2
2021.7.16 现役

posted @ 2021-07-16 17:16  Geek_kay  阅读(73)  评论(0编辑  收藏  举报