codeforces 600E - Lomsat gelral (dsu on tree)
题目链接:https://codeforces.com/problemset/problem/600/E
一直没有点这个技能点,今天跟队友打训练赛,碰到一道 \(dsu\ on\ tree\) 的题写不出来,就回来把这个题写了
\(dsu\ on\ tree\) 运用了轻重链剖分的思想,先处理轻儿子的答案,然后消去轻儿子的影响,最后处理重儿子,并保留重儿子的答案,
时间复杂度 \(O(nlogn)\)
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;
typedef long long ll;
const int maxn = 200010;
int n;
int h[maxn], cnt = 0;
struct E{
int to, next;
}e[maxn << 1];
void add(int u, int v){
e[++cnt].next = h[u];
e[cnt].to = v;
h[u] = cnt;
}
ll sum, ans[maxn];
int mx, Son, sz[maxn], son[maxn], c[maxn], cn[maxn * 10];
void dfs1(int u, int par){
sz[u] = 1;
int mx = 0;
for(int i = h[u] ; i != -1; i = e[i].next){
int v = e[i].to;
if(v == par) continue;
dfs1(v, u);
if(sz[v] > mx){
mx = sz[v];
son[u] = v;
}
sz[u] += sz[v];
}
}
void update(int u, int par, int val){
cn[c[u]] += val;
if(cn[c[u]] > mx) mx = cn[c[u]], sum = c[u];
else if(cn[c[u]] == mx) sum += c[u];
for(int i = h[u] ; i != -1 ; i = e[i].next){
int v = e[i].to;
if(v == par || v == Son) continue;
update(v, u, val);
}
}
void dfs2(int u, int par, int is){ // is: 是不是重儿子
// 处理当前子树的轻儿子的答案,处理完后轻儿子的答案被消除
for(int i = h[u] ; i != -1 ; i = e[i].next){
int v = e[i].to;
if(v == par || v == son[u]) continue;
dfs2(v, u, 0);
}
// 再处理重儿子 重儿子的答案保留
sum = 0, mx = 0;
if(son[u]) dfs2(son[u], u, 1);
Son = son[u];
update(u, par, 1); Son = 0; // 计算当前子树的答案,跳过保留答案的重儿子,只暴力统计轻儿子的答案
// for(int i = 1 ; i <= n ; ++i){
// printf("%d ", cn[i]);
// } printf("\n");
ans[u] = sum;
if(!is) update(u, par, -1); // 如果是父亲的轻儿子,就消去当前子树的影响
}
ll read(){ ll s = 0, f = 1; char ch = getchar(); while(ch < '0' || ch > '9'){ if(ch == '-') f = -1; ch = getchar(); } while(ch >= '0' && ch <= '9'){ s = s * 10 + ch - '0'; ch = getchar(); } return s * f; }
int main(){
memset(h, -1, sizeof(h));
n = read();
for(int i = 1 ; i <= n ; ++i){
c[i] = read();
}
int u, v;
for(int i = 1 ; i < n ; ++i){
u = read(), v = read();
add(u, v), add(v, u);
}
dfs1(1, 0);
// for(int i = 1 ; i <= n ; ++i) printf("%d ", son[i]); printf("\n");
dfs2(1, 0, 1);
for(int i = 1 ; i <= n ; ++i){
printf("%lld ", ans[i]);
} printf("\n");
return 0;
}