树上启发式合并(dsu on tree)
树上启发式合并
在做一类离线,答案与子树贡献有关的树上统计问题时暴力做法需要dfs,依次处理每个子结点,并在回溯时判断是否满足要求,然后清空避免对其兄弟节点产生影响,这会使时间复杂度退化至 \(O(N^2)\)。
考虑优化,我们发现在dfs时x的最后一个子节点不需要清空,它的贡献可以直接加入x的答案中。
根据这一性质,我们可以预处理出每个节点的最大子树,该子树的根节点称为重儿子,并在dfs时选择最后处理重儿子的贡献。
可以证明这样优化后的时间复杂度为 \(O(N\log N)\)。
证明:
显然一个节点需要清空,当且仅当它或祖先为轻儿子(意思是不是重儿子)。
于是易知一个节点被清空(注意当它的祖先节点被清空时也算作该节点被清空)的最大次数等于它到根路径上的轻边数。
又因为在一棵有 \(n\) 个节点的 \(k\) 叉树中(\(k>1\) ,因为一条链构成的树一定只需计算一次),每向上走一条边,子树大小就至少扩大一倍,所以在任意一个节点到根的路径上至多有 \(\log_k N\) 条轻边。
所以,一个节点会被清空 \(O(\log N)\) 次。
总时间复杂度为 \(O(N\log N)\) 。
\(\texttt{证毕。}\)
这就是树上启发式合并思想,又称dsu on tree或静态链分治。它可以理解成一种利用重链剖分的性质来优化计算子树贡献的做法。
现在看一道例题:
\(\mathbf{CF6000E\ Lomsat\ gelral}\)
求每棵子树中出现次数最多的颜色(可以有多种颜色)的编号之和。
记 \(heavyson[i]\) 表示 \(i\) 是否为重儿子,对重儿子最后计算贡献,不清空。
代码如下:
//CF600E Lomsat gelal
#include<iostream>
#include<algorithm>
#include<string>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#define N 100005
using namespace std;
struct Edge{
int nxt,to;
}e[N<<1];
int n,tot,head[N],co[N],siz[N];
bool hson[N];
long long ans[N],a[N],maxa,cntans;
inline void addedge(int x,int y)
{
e[++tot]=(Edge){head[x],y};
head[x]=tot;
}
void getsiz(int x,int pre)
{
int maxsiz=0,Hson=0;
siz[x]=1;
for(int i=head[x],y;i;i=e[i].nxt)
{
y=e[i].to;
if(y==pre) continue;
getsiz(y,x);
siz[x]+=siz[y];
if(siz[y]>maxsiz)
maxsiz=siz[y],Hson=y;
}
hson[Hson]=Hson?1:0;
}
void getans(int x,int pre,int Hson)
{
a[co[x]]++;
if(a[co[x]]>maxa)
maxa=a[co[x]],cntans=co[x];
else if(a[co[x]]==maxa)
cntans+=co[x];
for(int i=head[x],y;i;i=e[i].nxt)
{
y=e[i].to;
if(y==pre||y==Hson) continue;
getans(y,x,Hson);
}
}
void clear(int x,int pre)
{
a[co[x]]--;
for(int i=head[x],y;i;i=e[i].nxt)
{
y=e[i].to;
if(y==pre) continue;
clear(y,x);
}
}
void solve(int x,int pre)
{
int Hson=0;
for(int i=head[x],y;i;i=e[i].nxt)
{
y=e[i].to;
if(y==pre) continue;
if(!hson[y])
{
solve(y,x);
clear(y,x),maxa=cntans=0;
}
else Hson=y;
}
if(Hson)
solve(Hson,x);
getans(x,pre,Hson);
ans[x]=cntans;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",co+i);
for(int i=1,x,y;i<=n-1;i++)
scanf("%d%d",&x,&y),addedge(x,y),addedge(y,x);
getsiz(1,0);
solve(1,0);
for(int i=1;i<=n;i++)
printf("%lld ",ans[i]);
return 0;
}