Dsu on tree 学习笔记
\(\text{dsu on tree}\) 略解
简介
首先 \(\text{dsu on tree}\) 和并查集并没有关系,其用来处理一类树上问题,一般有两个特征:
- 不带修改
- 询问与子树有关
\(\text{dsu on tree}\) 可以十分方便的在 \(O(nlogn)\) 的时间复杂度内解决。
大致思路
\(\text{dsu on tree}\) 利用了重链剖分中重儿子的思想来进行暴力。
例如一道例题 CF600E:求每个子树内出现次数最多的颜色之和
\(O(n^2)\) 暴力十分显然,但是可以发现一个性质:父子之间的信息共享,而兄弟之间的信息不共享,也就是计算完最后一个子树的信息后,可以不用清空,其信息可以保留下来继续给父亲使用。
所以我们想到使最后一个遍历的子树尽可能大,也就是 重儿子。
算法流程
设当前求到 \(u\) 的答案 \(ans_u\),算法大致分为 \(5\) 步:
- 计算 轻儿子 \(v\) 的 \(ans_v\)
- 计算 \(u\) 重儿子 \(son_u\) 的 \(ans_{son_u}\),并将 \(son_u\) 的信息保留继续使用
- 再暴力计算每个轻儿子的信息
- 更新 \(ans_u\)
- 如果 \(u\) 不为重儿子,则暴力删去 \(u\) 的信息
”暴力计算 \(v\)“ 指将以 \(v\) 为根的子树遍历一遍计算信息(也可能因题目而异吧)
复杂度
首先有一个重要的性质:一个节点到根路径上的轻边数不超过 \(logn\),证明:
由轻重儿子的性质可知:对于 \(u\) 的任意轻儿子 \(v\) 有 \(siz_v \leq \frac{siz_u}{2}\)
因此每经过一条轻边 \(siz/2\),那么任意点开始往叶子节点走经过轻边数量最多不超过 \(logn\) 条
得证
再考虑每个点 \(v\) 会被计算多少次,按其到根的路径上的轻/重边分为两类讨论:
- 对于每条轻边,都需要单独计算一次 \(v\) 的信息,由以上性质知不超过 \(logn\) 次
- 对于 \(v\) 到根路径上的每条重边,是不需要再计算 \(v\) 的
所以对于节点 \(v\),一共会被计算 \(logn + 1\) 次(\(1\) 为计算 \(ans_v\))
综上,若计算一个点的信息为 \(O(1)\),则该算法时间复杂度为 \(O(nlogn)\)。
例题
CF600E
第一次打 \(Code\) 有点丑......
Code
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define N 100000
#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define Fo(i, u) for(int i = head[u]; i; i = edge[i].next)
#define ll long long
void read(int &x) {
char ch = getchar(); x = 0;
while (ch < '0' || ch > '9') ch = getchar();
while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + ch - 48, ch = getchar();
}
struct EDGE { int next, to; } edge[N << 1];
int head[N + 1], col[N + 1], h[N + 1], sz[N + 1], son[N + 1], las[N + 1];
ll ans[N + 1];
int n;
int cnt_edge = 1;
void Add(int u, int v) { edge[ ++ cnt_edge ] = (EDGE) { head[u], v }, head[u] = cnt_edge; }
void Link(int u, int v) { Add(u, v), Add(v, u); }
void Dfs1(int u, int la) {
sz[u] = 1, son[u] = 0;
Fo(i, u) if (i != la) {
Dfs1(edge[i].to, i ^ 1);
if (sz[edge[i].to] > sz[son[u]])
son[u] = edge[i].to;
sz[u] += sz[edge[i].to];
}
}
int max_h = 0;
ll sum = 0;
void Add1(int c, int d) {
h[c] += d;
if (h[c] > max_h) max_h = h[c], sum = c;
else if (h[c] == max_h) sum += c;
}
void Dfs3(int u, int la, int d) {
Add1(col[u], d);
Fo(i, u) if (i != la)
Dfs3(edge[i].to, i ^ 1, d);
}
void Dfs2(int u, int fa, int opt) {
int v = 0;
Fo(i, u) if ((v = edge[i].to) != fa && v != son[u])
Dfs2(v, u, 1);
if (son[u]) Dfs2(son[u], u, 0);
Fo(i, u) if ((v = edge[i].to) != fa && v != son[u])
Dfs3(v, i ^ 1, 1);
Add1(col[u], 1);
ans[u] = sum;
if (opt) {
Fo(i, u) if ((v = edge[i].to) != fa)
Dfs3(v, i ^ 1, -1);
Add1(col[u], -1);
sum = max_h = 0;
}
}
int main() {
read(n);
fo(i, 1, n) read(col[i]);
for (int i = 1, x, y; i < n; i ++)
read(x), read(y), Link(x, y);
Dfs1(1, 0);
Dfs2(1, 0, 0);
fo(i, 1, n) printf("%lld ", ans[i]);
return 0;
}
CF741D
Solution
由回文串的性质可知:区间内之多只有一个字符出现奇数次。
借此可以将统计出现次数转化为异或,可以用大小为 \(2^{22}\) 的状态表示从根开始的路径上每个字符出现次数的奇偶性。
设 \(dis_{u}\) 为从根到 \(x\) 的路径上字符的状态,那么任意路径 \((u, v)\) 的字符状态就可以表示为 \(dis_{(u, v)} = dis_u \oplus dis_v \oplus dis_{lca} \oplus dis_{lca}\),由异或的性质可知即为 \(dis_u \oplus dis_v\),而距离就是 \(dep_u + dep_v - 2dep_{lca}\)。
所以只需要用大小 \(2^{22}\) 的桶存下每个状态的最深深度,\(O(22)\) 可以求出一个点对答案的贡献,每个点 \(u\) 的 \(ans_u\) 为其子树 \(ans\) 与经过 \(u\) 最长符合条件路径的最大值。
剩下的就基本是 \(\text{dsu on tree}\) 的模板了。
Code
#include <cstdio>
using namespace std;
#define N 500000
#define M 22
#define inf 10000
#define fo(i, x, y) for(int i = x, end_##i = y; i <= end_##i; i ++)
#define fd(i, x, y) for(int i = x, end_##i = y; i >= end_##i; i --)
#define Fo(i, u) for(int i = head[u]; i; i = edge[i].next)
void read(int &x) {
char ch = getchar(); x = 0;
while (ch < '0' || ch > '9') ch = getchar();
while (ch >= '0' && ch <= '9') x = (x << 1) + (x << 3) + ch - 48, ch = getchar();
}
struct EDGE { int next, to; } edge[N << 1];
int head[N + 1], col[N + 1], f[1 << M], d[N + 1], sz[N + 1], son[N + 1], fa[N + 1], c[N + 1], ans[N + 1];
int n;
int cnt_edge = 0;
void Add(int u, int v) { edge[ ++ cnt_edge ] = (EDGE) { head[u], v }, head[u] = cnt_edge; }
int max(int x, int y) { return x > y ? x : y; }
void Init() {
d[1] = 1;
fo(i, 2, n) c[i] = c[fa[i]] ^ (1 << col[i]), d[i] = d[fa[i]] + 1;
fd(i, n, 1) {
if (++ sz[i] > sz[son[fa[i]]])
son[fa[i]] = i;
sz[fa[i]] += sz[i];
}
fo(i, 2, n) if (i != son[fa[i]])
Add(fa[i], i);
fo(i, 0, (1 << M) - 1) f[i] = -inf;
}
int Get_d(int x) {
int dep = f[x];
fo(i, 0, M - 1)
dep = max(dep, f[x ^ (1 << i)]);
return dep;
}
int Dfs1(int u) {
int dep = Get_d(c[u]) + d[u];
Fo(i, u) dep = max(dep, Dfs1(edge[i].to));
if (son[u]) dep = max(dep, Dfs1(son[u]));
return dep;
}
void Updata(int x, int dep) { f[x] = max(f[x], dep); }
void Dfs2(int u) {
Updata(c[u], d[u]);
Fo(i, u) Dfs2(edge[i].to);
if (son[u]) Dfs2(son[u]);
}
void Back(int x) { f[x] = -inf; }
void Dfs3(int u) {
Back(c[u]);
Fo(i, u) Dfs3(edge[i].to);
if (son[u]) Dfs3(son[u]);
}
void Solve(int u, int opt) {
ans[u] = 0;
Fo(i, u) Solve(edge[i].to, 1), ans[u] = max(ans[u], ans[edge[i].to]);
if (son[u]) Solve(son[u], 0), ans[u] = max(ans[u], ans[son[u]]);
ans[u] = max(ans[u], Get_d(c[u]) - d[u]);
Updata(c[u], d[u]);
Fo(i, u)
ans[u] = max(ans[u], Dfs1(edge[i].to) - (d[u] << 1)), Dfs2(edge[i].to);
if (opt) Dfs3(u);
}
int main() {
read(n);
fo(i, 2, n)
read(fa[i]), col[i] = getchar() - 'a';
Init();
Solve(1, 0);
fo(i, 1, n) printf("%d ", ans[i]);
return 0;
}