初识树链剖分

首发于摸鱼世界&更好的阅读体验

到现在也只会照着std打板子..

虽然这样,树链剖分还是一个非常优雅的算法。


前置芝士:\(DFS\),线段树

树链剖分可以把树上的区间操作通过把树剖成一条条链,利用线段树数据结构进行维护,从而达到\(O(nlogn)\)的优秀时间复杂度。

比如这样的操作:

在一棵树上,将\(x\)\(y\)路径上点的点权加上\(w\),并要求支持查询两个点\(x,y\)路径间的点权和。

乍一看,两个操作都很简单。修改操作可以用树上差分\(O(1)\)乱搞,静态查询可以用\(LCA\)完成。

但是合起来就没有办法了:每次查询之前都需要\(O(n)\)预处理,数据略大直接\(T\)飞。

于是树剖出场了。


区间修改&查询是线段树的强项,但是它只能对一段连续的区间进行查询。于是我们需要想办法让树上需要操作的路径变成一段连续的区间。

引入一个概念:重儿子,也就是一个节点的儿子中\(size\)最大的。连接到重儿子的边即为重边

重儿子组成的,就是重链

比如在这棵树中,连续的红边组成的就是一条条重链。我们用\(top[u]\)记录节点\(u\)所在重链的顶端。特别地,没有被重边连接的节点,\(top[u]=u\),即它们所在重链的顶端就是自身。注意到,当\(u\)是一条重链的顶端(\(top[u]=u\))时,它的父节点一定在另一条重链上

始终记住我们的目标:把在树上区间操作转化为在一段连续的区间进行操作。

考虑如何用\(DFS\)给树上的每个节点在区间内找到一个合适的位置。我们发现,从根节点出发,优先走重边,这样的\(dfs\)序似乎有点特殊。

例如上图,优先走重边的\(dfs\)序为:\(124798356\)。很显然,这样的\(dfs\)序满足同一条重链上的点\(dfs\)序连续。所以用线段树维护的,就是重链上的信息

这样操作之后,我们可以做到的是:\(O(logn)\)对一条重链上的信息区间修改,区间查询。

对于两个节点\(u,v\),我们可以通过不断地跳重链,直到两个节点在同一条重链上。这个是很好实现的,因为只需要跳到\(fa[top[u]]\),就到了一条新的重链。

代码实现仅树剖部分是不麻烦的。我们需要维护的信息有\(dep\)(节点深度),\(fa\)(父节点),\(son\)(重儿子),\(sz\)(子树节点数,用来判重儿子),这些可以用一次\(dfs\)完成。

void dfs1(int u,int f,int d)//fa,dep,son,sz
{
	fa[u]=f;
	dep[u]=d;
	sz[u]=1;
	for(int i=head[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v!=f)
		{
			dfs1(v,u,d+1);
			sz[u]+=sz[v];
			if(sz[v]>sz[son[u]])son[u]=v;
		}

	}
}

接下来,就需要把这棵树每个节点压到线段树维护的序列的一个位置了。就像上文说的一样,按照优先重边\(dfs\)序压入线段树即可。于是记录一个\(id[i]\)表示原树中节点\(i\)对应的线段树中的下标。\(rk[i]\)反过来记录线段树中下标为\(i\)的原数编号。

由于预处理了父节点,所以\(dfs2\)传参只需要\(u\)(当前节点)和\(t\)(当前重链顶端节点)。在遍历儿子之前先\(dfs2(son[u],t)\),因为\(u\)\(u\)的重儿子在同一条重链上。接下来才遍历轻(非重)儿子\(v\),但是传参为\(dfs2(v,v)\),因为\(v\)就是新的一条重链的起点。

void dfs2(int u,int t)//top,id,rk
{
	top[u]=t;
	id[u]=++tot;
	rk[tot]=u;
	if(!son[u])return;
	dfs2(son[u],t);
	for(int i=head[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v!=fa[u]&&v!=son[u])
			dfs2(v,v);
	}
}

再回到最开始的问题:

在一棵树上,将\(x\)\(y\)路径上点的点权加上\(w\),并要求支持查询两个点\(x,y\)路径间的点权和。

答案就显得很明了了。

如果是查询,先保证\(dep[x]>dep[y]\),然后就和\(LCA\)类似的,利用重链加速:每次把\([top[x],x]\)这条重链的和累加到答案上,再使\(x\)跳到另一条重链上,即\(x=fa[top[x]]\),直到\(x,y\)在同一条重链上,再把两个点之间的信息统计累加一下即可。

int getsum(int x,int y)
{
	int res=0;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		sum=0;
		asksum(1,id[top[x]],id[x]);
		(res+=sum)%=mod;
		x=fa[top[x]];
	}
	if(id[x]>id[y])swap(x,y);
	sum=0;
	asksum(1,id[x],id[y]);
	(res+=sum)%=mod;
	return res;
}

修改同理。

于是我们发现,虽然我们采用了优先重边的\(dfs\)序,但它毕竟遍历的都是自己的儿子节点。所以...还可以支持子树操作。因为一棵子树在重边优先的\(dfs\)序中编号也是连续的。并且这个编号很容易算,因为我们维护了一个\(sz\)信息。所以树中\(x\)节点的子树对应的就是线段树维护的\([id[x],id[x]+sz[x]-1]\)这个区间

于是还是板子一般的线段树区间修改&查询。


可以注意到线段树部分基本没讲,因为每个人写线段树的方法可能不太一样,蒟蒻我分享的只是树剖的思想。

另外,为什么树剖每次操作是\(O(logn)\)呢?利用线段树的子树操作自然是\(O(logn)\),剩下的就是那个像\(LCA\)一样的跳重链。

证明:从任意节点向根节点跳重链,经过的重链和轻边(非重边)都是\(log\)级别的。

考虑到每走一条轻边,子树大小至少翻倍,否则这就不是条轻边了。于是经过的轻边就最多为\(log_2 n\)条。而重链和轻边的交替出现的,所以数量也在这个级别。

于是每次操作就只有\(O(logn)\)的时间复杂度。

模板题

以下是代码

#include<bits/stdc++.h>
#define int long long
#define ls (k<<1)
#define rs (k<<1|1)
using namespace std;
const int N=1e5+10;
struct node
{
	int l,r,w,f;
}t[N<<2];
int a[N];
int n,m,r,mod;
int sum;
int head[N<<1],to[N<<1],nxt[N<<1],cnt;
int sz[N],fa[N],dep[N],son[N];
int top[N],id[N],rk[N],tot;
inline int read()
{
   int x=0,f=1;
   char ch=getchar();
   while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
   while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
   return x*f;
}
void add(int u,int v)
{
	cnt++;
	to[cnt]=v;
	nxt[cnt]=head[u];
	head[u]=cnt;
}
void dfs1(int u,int f)
{
	fa[u]=f;
	sz[u]=1;
	dep[u]=dep[f]+1;
	for(int i=head[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v==f)continue;
		dfs1(v,u);
		sz[u]+=sz[v];
		if(sz[v]>sz[son[u]])son[u]=v;
	}
	return;
}
void dfs2(int u,int t)
{
	top[u]=t;
	id[u]=++tot;
	rk[tot]=u;
	if(!son[u])return;
	dfs2(son[u],t);
	for(int i=head[u];i;i=nxt[i])
	{
		int v=to[i];
		if(v!=fa[u]&&v!=son[u])dfs2(v,v);//新的重链 
	}
}
void build(int k,int l,int r)
{
	t[k].l=l,t[k].r=r;
	if(l==r)
	{
		t[k].w=a[rk[l]];
		return;
	}
	int m=l+r>>1;
	build(ls,l,m);
	build(rs,m+1,r);
	t[k].w=t[ls].w+t[rs].w;
	return;
}
void down(int k)
{
	t[ls].w+=(t[ls].r-t[ls].l+1)*t[k].f;
	t[rs].w+=(t[rs].r-t[rs].l+1)*t[k].f;
	t[ls].f+=t[k].f;
	t[rs].f+=t[k].f;
	t[k].f=0;
}
void addsum(int k,int x,int y,int p)
{
	int l=t[k].l,r=t[k].r;
	if(x<=l&&r<=y)
	{
		t[k].w+=(r-l+1)*p;
		t[k].f+=p;
		return;
	}
	down(k);
	int m=l+r>>1;
	if(x<=m)addsum(ls,x,y,p);
	if(y>m)addsum(rs,x,y,p);
	t[k].w=t[ls].w+t[rs].w;
	return;
}
void asksum(int k,int x,int y)
{
	int l=t[k].l,r=t[k].r;
	if(x<=l&&r<=y)
	{
		sum+=t[k].w;
		return;
	}
	down(k);
	int m=l+r>>1;
	if(x<=m)asksum(ls,x,y);
	if(y>m)asksum(rs,x,y);
	t[k].w=t[ls].w+t[rs].w;
	return;
}
//-----------------------------
int getsum(int x,int y)
{
	int res=0;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		sum=0;
		asksum(1,id[top[x]],id[x]);
		(res+=sum)%=mod;
		x=fa[top[x]];
	}
	if(id[x]>id[y])swap(x,y);
	sum=0;
	asksum(1,id[x],id[y]);
	(res+=sum)%=mod;
	return res;
}
void update(int x,int y,int p)
{
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		addsum(1,id[top[x]],id[x],p);
		x=fa[top[x]];
	}
	if(id[x]>id[y])swap(x,y);
	addsum(1,id[x],id[y],p);
	return;
}
signed main()
{
	n=read(),m=read(),r=read(),mod=read();
	for(int i=1;i<=n;i++)a[i]=read();
	for(int i=1;i<n;i++)
	{
		int x=read(),y=read();
		add(x,y),add(y,x);
	}
	dfs1(r,0);
	dfs2(r,r);
	build(1,1,n);
	for(int i=1;i<=m;i++)
	{
		int x,y,z;
		int opt=read();
		if(opt==1)
		{
			x=read(),y=read(),z=read();
			update(x,y,z);
		}
		if(opt==2)
		{
			x=read(),y=read();
			printf("%lld\n",getsum(x,y)%mod);
		}
		if(opt==3)
		{
			x=read(),z=read();
			addsum(1,id[x],id[x]+sz[x]-1,z);
		}
		if(opt==4)
		{
			x=read();
			sum=0;asksum(1,id[x],id[x]+sz[x]-1);
			printf("%lld\n",sum%mod);
		}
	}
	return 0;
}

代码的确是长,也不算容易调,但是真正妙的是利用轻重链的思想进行的化树为链。

感谢阅读。

posted @ 2020-08-06 15:04  摸鱼酱  阅读(125)  评论(0编辑  收藏  举报