[学习笔记]dsu on tree
这种不怎么难写的东西,我学得快忘得也快,也是给自己加深印象,同时留个自己(大概)能看懂的讲解好复习……qwq
先说是什么
dsu on tree中的dsu就是Disjoint Set Union,虽然整个算法跟并茶几(话说并茶几名字好多啊……)没有任何关系……硬要说就是借用了启发式合并的思想吧……
这个算法是拿来解决树上对子树内答案的询问的,当然它并不支持修改
它在暴力的基础上,借助轻重链剖分的性质把复杂度降低到了\(O(n \log n)\)
大致过程
遍历每一个节点,先递归解决轻儿子,完成后消除递归产生的影响
然后解决重儿子,但不消除影响
将轻儿子的答案合并上来
消除上一个过程中轻儿子产生的影响
拿一个例题来说:CF600E
题意:树上每个节点有一个颜色,求每棵子树中出现次数最多的颜色(可能有多个)之和
首先是轻重链剖分,处理出每个节点的重儿子
void dfs(int u, int fa) {
size[u] = 1;
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa) continue;
dfs(v, u);
size[u] += size[v];
if (!heavy[u] || size[v] > size[heavy[u]]) heavy[u] = v;
}
}
然后遍历节点,按上面的流程来(具体看注释)
void update(int u, int fa, int val, const int &hvy) {//val为1暴力合并统计轻儿子的答案,为-1清除对cnt的影响
cnt[col[u]] += val;
if (val > 0 && cnt[col[u]] >= max_cnt) {
if (cnt[col[u]] > max_cnt) sum = 0, max_cnt = cnt[col[u]];
sum += (LL)col[u];
}
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa || v == hvy) continue;
update(v, u, val, hvy);
}
}
void dfs(int u, int fa, int opt) {//opt为0表示需要清除掉u的影响,为1表示不需要
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa || v == heavy[u]) continue;
dfs(v, u, 0);//递归解决轻儿子,完成后清除影响
}
if (heavy[u]) dfs(heavy[u], u, 1);//解决重儿子,保留影响
update(u, fa, 1, heavy[u]);//合并轻儿子的答案
ans[u] = sum;
if (!opt) update(u, fa, -1, 0), sum = 0, max_cnt = 0;//清除影响
}
最后是完整代码:
PS.怎么网上的博客代码个个不一样啊……,蒟蒻我懵逼了好长时间才看明白qwq
#include <cstdio>
#include <cstring>
#include <iostream>
#define MAXN 100005
typedef long long LL;
struct Graph {
struct Edge {
int v, next;
Edge(int _v = 0, int _n = 0):v(_v), next(_n) {}
} edge[MAXN << 1];
int head[MAXN], cnt;
void init() { memset(head, -1, sizeof head); cnt = 0; }
void add_edge(int u, int v) { edge[cnt] = Edge(v, head[u]); head[u] = cnt++; }
void insert(int u, int v) { add_edge(u, v); add_edge(v, u); }
Edge & operator [](int x) { return edge[x]; }
} G;
int col[MAXN], size[MAXN], heavy[MAXN], val[MAXN], cnt[MAXN], N;
LL sum, max_cnt, ans[MAXN];
void dfs(int, int);
void dfs(int, int, int);
void update(int, int, int, const int &);
int main() {
G.init();
scanf("%d", &N);
for (int i = 1; i <= N; ++i) scanf("%d", col + i);
for (int i = 1; i < N; ++i) {
int u, v;
scanf("%d%d", &u, &v);
G.insert(u, v);
}
dfs(1, 0);
dfs(1, 0, 0);
for (int i = 1; i <= N; ++i) printf("%I64d ", ans[i]);
return 0;
}
void dfs(int u, int fa) {
size[u] = 1;
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa) continue;
dfs(v, u);
size[u] += size[v];
if (!heavy[u] || size[v] > size[heavy[u]]) heavy[u] = v;
}
}
void update(int u, int fa, int val, const int &hvy) {
cnt[col[u]] += val;
if (val > 0 && cnt[col[u]] >= max_cnt) {
if (cnt[col[u]] > max_cnt) sum = 0, max_cnt = cnt[col[u]];
sum += (LL)col[u];
}
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa || v == hvy) continue;
update(v, u, val, hvy);
}
}
void dfs(int u, int fa, int opt) {
for (int i = G.head[u]; ~i; i = G[i].next) {
int v = G[i].v;
if (v == fa || v == heavy[u]) continue;
dfs(v, u, 0);
}
if (heavy[u]) dfs(heavy[u], u, 1);
update(u, fa, 1, heavy[u]);
ans[u] = sum;
if (!opt) update(u, fa, -1, 0), sum = 0, max_cnt = 0;
}
//Rhein_E
最后是复杂度证明
轻重链剖分保证了每个节点到根的路径上轻边条数不超过\(\log n\)
每个节点被访问一次,要么是它的祖先节点暴力统计轻儿子/消除影响的时候,要么是它自己统计答案的时候
前者\(O(\log n)\)次,后者\(1\)次
所以总复杂度是\(O(n \log n)\)的