浅谈树上启发式合并(DSU on tree)
先看例题
CF600E Lomsat gelral
对于这道题,我们要对每一个子树进行查询。观察一下,很容易想到通过信息的合并让父节点继承子节点的信息。但是对于每个点都开一个桶会MLE
sol1 线段树合并
对于每一个节点都开一个线段树,然后进行线段树合并即可。这样做虽然足以通过本题,但空间巨大。
sol2 树上启发式合并
切入正题。如果我们只开一个桶,那空间还是在合理的范围之内。
我们可以这样来实现:对于每一个节点,先搜一遍它的每一个子树,做完就将桶清空清空,以免对其它子树产生影响。然后再暴力遍历一遍所有子树,合并它们的答案(每种颜色出现的个数),然后记录当先节点的答案
但这未免太过于暴力了。我们容易发现,我们可以留一颗子树不去清空,最后将再所有其它子树的答案合并到它上面。容易想到,这颗子树应该为重儿子。
即:对于每一个节点,先搜一遍它的每一个轻儿子,做完就将桶清空清空,以免对其它子树产生影响。然后在做一遍重儿子。接着再暴力遍历一遍所有轻儿子,合并它们的答案(每种颜色出现的个数),然后记录当先节点的答案。
一开始我觉得第一遍遍历轻儿子很费,反正后面还要遍历一遍。其实不然。第一遍遍历是用时间节省了空间。不然会MLE。
最后还要注意一个小细节:每一颗子树的答案可以 \(O(1)\) 的由桶得到。我们可以记录下当前主导颜色的出现次数以及当前的答案,每次增加颜色只需检查那个颜色的出现次数是否大于等于当前主导颜色的出现次数即可。
虽然直觉上感觉这玩意快不了多少,但它是 \(O(nlogn)\)的。
代码实现
#include<bits/stdc++.h>
using namespace std;
#define int long long
inline int read()
{
register int x=0;
register char a=getchar();
while (a<'0'||a>'9') a=getchar();
while (a>='0'&&a<='9') x=(x<<1)+(x<<3)+(a^48),a=getchar();
return x;
}
inline void write(int x) {
if (x < 0) {
x = ~(x - 1);
putchar('-');
}
if (x > 9)
write(x / 10);
putchar(x % 10 + '0');
}
int col[100005];
int n;
vector <int> g[100005];
int siz[100005], son[100005];
int ans[100005], sum[100005];
int maxn, cnt;
void dfs(int u, int fa){//找重儿子
// cout<<u<<endl;
siz[u] = 1;
for(int v:g[u]){
if(v == fa) continue;
dfs(v, u);
siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
}
}
void Clear(int u, int fa){//清空
sum[col[u]]--;
for(int v:g[u]){
if(v == fa) continue;
Clear(v, u);
}
cnt = maxn = 0;
}
void ins(int u){//计算新加入的颜色对于答案的贡献
sum[col[u]]++;
if(sum[col[u]] == maxn){
cnt += col[u];
}else if(sum[col[u]] > maxn){
maxn = sum[col[u]];
cnt = col[u];
}
}
void dfs3(int u, int fa){//暴力收轻儿子
ins(u);
for(int v:g[u]){
if(v == fa) continue;
dfs3(v, u);
}
}
void dfs2(int u, int fa){
for(int v:g[u]){
if(v == fa or v == son[u]) continue;
dfs2(v, u);//先做轻儿子
Clear(v, u);//做完清空
}
if(son[u]) dfs2(son[u], u);//搜重儿子
for(int v :g[u]){
if(v == son[u] or v == fa) continue;
dfs3(v, u);//暴力搜子树
}
ins(u);
ans[u] = cnt;//记录答案
}
signed main(){
n = read();
for(int i = 1; i <= n; i++) col[i] = read();
for(int i = 1; i < n; i++){
int u = read(), v = read();
g[u].push_back(v);
g[v].push_back(u);
}
dfs(1, 0);
dfs2(1, 0);
for(int i = 1; i <= n; i++){
write(ans[i]);
putchar(' ');
}
return 0;
}