树上启发式合并(dsu on tree)

树上启发式合并

在做一类离线,答案与子树贡献有关的树上统计问题时暴力做法需要dfs,依次处理每个子结点,并在回溯时判断是否满足要求,然后清空避免对其兄弟节点产生影响,这会使时间复杂度退化至 \(O(N^2)\)
考虑优化,我们发现在dfs时x的最后一个子节点不需要清空,它的贡献可以直接加入x的答案中。
根据这一性质,我们可以预处理出每个节点的最大子树,该子树的根节点称为重儿子,并在dfs时选择最后处理重儿子的贡献。
可以证明这样优化后的时间复杂度为 \(O(N\log N)\)

证明
显然一个节点需要清空,当且仅当它或祖先为轻儿子(意思是不是重儿子)。
于是易知一个节点被清空(注意当它的祖先节点被清空时也算作该节点被清空)的最大次数等于它到根路径上的轻边数。
又因为在一棵有 \(n\) 个节点的 \(k\) 叉树中(\(k>1\) ,因为一条链构成的树一定只需计算一次),每向上走一条边,子树大小就至少扩大一倍,所以在任意一个节点到根的路径上至多有 \(\log_k N\) 条轻边。
所以,一个节点会被清空 \(O(\log N)\) 次。
总时间复杂度为 \(O(N\log N)\)
\(\texttt{证毕。}\)

这就是树上启发式合并思想,又称dsu on tree或静态链分治。它可以理解成一种利用重链剖分的性质来优化计算子树贡献的做法。
现在看一道例题:
\(\mathbf{CF6000E\ Lomsat\ gelral}\)
求每棵子树中出现次数最多的颜色(可以有多种颜色)的编号之和。
\(heavyson[i]\) 表示 \(i\) 是否为重儿子,对重儿子最后计算贡献,不清空。
代码如下:

//CF600E Lomsat gelal
#include<iostream>
#include<algorithm>
#include<string>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#define N 100005
using namespace std;
struct Edge{
	int nxt,to;
}e[N<<1];
int n,tot,head[N],co[N],siz[N];
bool hson[N];
long long ans[N],a[N],maxa,cntans;
inline void addedge(int x,int y)
{
	e[++tot]=(Edge){head[x],y};
	head[x]=tot;
}
void getsiz(int x,int pre)
{
	int maxsiz=0,Hson=0;
	siz[x]=1;
	for(int i=head[x],y;i;i=e[i].nxt)
	{
		y=e[i].to;
		if(y==pre)	continue;
		getsiz(y,x);
		siz[x]+=siz[y];
		if(siz[y]>maxsiz)
			maxsiz=siz[y],Hson=y;
	}
	hson[Hson]=Hson?1:0;
}
void getans(int x,int pre,int Hson)
{
	a[co[x]]++;
	if(a[co[x]]>maxa)
		maxa=a[co[x]],cntans=co[x];
	else if(a[co[x]]==maxa)
		cntans+=co[x];
	for(int i=head[x],y;i;i=e[i].nxt)
	{
		y=e[i].to;
		if(y==pre||y==Hson)	continue;
		getans(y,x,Hson);
	}
}
void clear(int x,int pre)
{
	a[co[x]]--;
	for(int i=head[x],y;i;i=e[i].nxt)
	{
		y=e[i].to;
		if(y==pre)	continue;
		clear(y,x);
	}
}
void solve(int x,int pre)
{
	int Hson=0;
	for(int i=head[x],y;i;i=e[i].nxt)
	{
		y=e[i].to;
		if(y==pre)	continue;
		if(!hson[y])
		{
			solve(y,x);
			clear(y,x),maxa=cntans=0;
		}
		else	Hson=y;
	}
	if(Hson)
		solve(Hson,x);
	getans(x,pre,Hson);
	ans[x]=cntans;
}
int main()
{
	scanf("%d",&n);
	for(int i=1;i<=n;i++)
		scanf("%d",co+i);
	for(int i=1,x,y;i<=n-1;i++)
		scanf("%d%d",&x,&y),addedge(x,y),addedge(y,x);
	getsiz(1,0);
	solve(1,0);
	for(int i=1;i<=n;i++)
		printf("%lld ",ans[i]);
	return 0;
}
posted @ 2020-10-24 13:58  LZShuing  阅读(133)  评论(0编辑  收藏  举报