「dsu on tree」学习笔记(优雅的暴力)
前置芝士
- 半熟练剖分
基本概念
-
\(dsu\):并查集
-
\(on tree\):在树上
即在树上的并查集
行吧,跟名字没多大关系。
可以通俗的理解为暴力,一般用来处理以下性质的问题:
-
一个根的答案由子树来贡献
-
不含任何修改操作
有同学可能就问了,这用点分治不就可以了吗?
两者有点差别:
-
点分治一般处理的是无根树,而 \(dsu\) 处理的是有根树。
-
点分治的效率是 \(nlog_n^2\),而 \(dsu\) 的效率是 \(nlog_n\) 。
基本思路
拿一个例题:CF600E Lomsat gelral
首先考虑暴力如何去做:
-
枚举每一个根,直接暴搜它的子树所有节点,统计答案即可,时间负责度 \(\Theta (n^2)\) 。
-
预处理出来每个子树的颜色序列,存起来,直接上传到其父亲即可,空间复杂度 \(\Theta (n^2)\) 。
显然两种暴力都不可取,但是两种暴力都各有所长,我们就将其长处融合在一起:\(dsu\) 。
-
根据暴力一的思路,我们可以将其每个轻儿子的子树暴力统计,每次统计完后直接清除即可。
-
根据暴力二的思路,我们可以不用将重儿子的答案清除,直接用于统计答案即可。
所以这样,我们既保证了时间效率的较优,又保证了空间负责度的合理,就可以很好的处理这个问题了。
- 小细节(决定成败):当你暴力统计轻儿子的子树答案时,将重儿子的子树跳过即可。
\(u1s1\),通俗(\(dsu\))易懂。
板板题代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define int long long
using namespace std;
const int maxn = 1e5 + 50, INF = 0x3f3f3f3f;
inline int read () {
int x = 0, w = 1;
char ch = getchar ();
for (; ch < '0' || ch > '9'; ch = getchar ()) if (ch == '-') w = -1;
for (; ch >= '0' && ch <= '9'; ch = getchar ()) x = x * 10 + ch - '0';
return x * w;
}
int n;
int col[maxn], ans[maxn];
struct Edge {
int to, next;
}e[maxn << 1];
int tot, head[maxn];
inline void Add (register int u, register int v) {
e[++ tot].to = v;
e[tot].next = head[u];
head[u] = tot;
}
int size[maxn], son[maxn];
inline void DFS1 (register int u, register int fa) { // 熟练剖分,统计重儿子
size[u] = 1;
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa) continue;
DFS1 (v, u);
size[u] += size[v];
if (size[son[u]] < size[v]) son[u] = v;
}
}
int cnt[maxn], sum, maxx, Son;
inline void Query (register int u, register int fa, register int w) {
cnt[col[u]] += w;
if (cnt[col[u]] > maxx) maxx = cnt[col[u]], sum = col[u]; // 更新答案
else if (cnt[col[u]] == maxx) sum += col[u];
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa || v == Son) continue; // 跳过记录的重儿子
Query (v, u, w);
}
}
inline void DFS2 (register int u, register int fa, register bool opt) { // opt = 1,表示不用清除,即重儿子;opt = 0,清除,即轻儿子
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa || v == son[u]) continue; // 重儿子要跳过
DFS2 (v, u, 0); // 轻儿子向下递归
}
if (son[u]) DFS2 (son[u], u, 1), Son = son[u]; // 重儿子向下递归
Query (u, fa, 1), Son = 0; // 统计轻儿子的答案
ans[u] = sum; // 记录答案
if (opt == 0) Query (u, fa, -1), sum = 0, maxx = 0; // 轻儿子清除
}
signed main () {
n = read();
for (register int i = 1; i <= n; i ++) {
col[i] = read();
}
for (register int i = 1; i < n; i ++) {
register int u = read(), v = read();
Add (u, v), Add (v, u);
}
DFS1 (1, 0);
DFS2 (1, 0, 1);
for (register int i = 1; i <= n; i ++) {
printf ("%lld ", ans[i]);
}
puts ("");
return 0;
}
例题
其实这个题用线段树可做,但是在某谷上交会暴内存,但是好像还是有人卡过了。
在这儿冲一发 \(dsu\) 的题解~~~
首先我们考虑这个题的性质:
-
一个子树里的若干个结点,其 \(LCA\) 一定在这个子树内。
-
若在 \(a\) 序列中,有一段区间的点,都在同一个子树内,则有 \(\frac {len\times (len - 1)}{2}\) 的区间数,即方案为 \(LCA\) 在这颗子树的方案数。
我们会发现,某个根节点的答案会由其子树贡献过来,这样我们就可以用 \(dsu\) 了。
连续段
若一颗子树内若干个节点能够形成一段区间,那么它们在 \(a\) 序列里是连续的,\(rank(rank[a[i]]= i)\) 就是连续的。
所以我们在便历一颗子树的时候,将每一个节点的 \(rank\) 值用一个 \(tmp\) 数组维护,\(tmp[rank[u]]\) 表示 \(u\) 节点在当前统计过的若干个节点,\(rank[u]\) 作为一个极长区间的左/右端点时,这个区间的长度。
若一颗子树的所有节点放到一个序列里,\(rank\) 值有连续的,就会有一段区间,则这个区间形成的方案数为 \(\frac {len\times (len - 1)} {2}\)。
合并
当我们插入一个数 \(x\)(某个节点的 \(rank\) 值),我们就可以将 \(x - 1\) 和 \(x + 1\) 所在的区间合并起来,并将 \(tmp\) 数组更新。
inline void Query (register int u) {
register int tmp1 = tmp[rank[u] - 1], tmp2 = tmp[rank[u] + 1]; // 得到左边区间的长度,右边区间的长度
register int len = tmp1 + tmp2 + 1; // 合并后区间的长度
tmp[rank[u] - tmp1] = tmp[rank[u] + tmp2] = len; // 将合并后的区间的左右端点的tmp更新
tmpans += len * (len - 1) / 2 - tmp1 * (tmp1 - 1) / 2 - tmp2 * (tmp2 - 1) / 2; // 注意减掉之前统计过的贡献
}
答案统计
但是我们会发现,我们统计的只是 \(LCA\) 在其子树内的方案数,举个栗子:
以 \(3\) 为根节点的子树,我们统计出来的答案是以 \(3,4,5\) 为 \(LCA\) 的答案和。
以 \(6\) 为根节点的子树,我们统计出来的答案是以 \(6\) 为 \(LCA\) 的答案和。
以 \(2\) 为根节点的子树,我们统计出来的答案是以 \(2,3,4,5,6\) 为 \(LCA\) 的答案和。
那么我们将 \(2\) 的答案减去其子树 \(3,6\) 的答案,就可以得出以 \(2\) 为 \(LCA\) 的答案。
最后乘上权值即可。
还有,我们会发现,我们只是统计了区间长度 \(\geqslant 2\) 的贡献,我们只需在求一个 \(sumval\) 加到答案上即可。
代码
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define int long long
using namespace std;
const int maxn = 5e5 + 50, INF = 0x3f3f3f3f;
inline int read () {
int x = 0, w = 1;
char ch = getchar ();
for (; ch < '0' || ch > '9'; ch = getchar ()) if (ch == '-') w = -1;
for (; ch >= '0' && ch <= '9'; ch = getchar ()) x = x * 10 + ch - '0';
return x * w;
}
int n, ans, sumval;
int a[maxn], val[maxn], rank[maxn];
struct Edge {
int to, next;
}e[maxn << 1];
int tot, head[maxn];
inline void Add (register int u, register int v) {
e[++ tot].to = v;
e[tot].next = head[u];
head[u] = tot;
}
int size[maxn], son[maxn];
inline void DFS1 (register int u, register int fa) {
size[u] = 1;
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa) continue;
DFS1 (v, u);
size[u] += size[v];
if (size[son[u]] < size[v]) son[u] = v;
}
}
int tmp[maxn], nowans, tmpans, num[maxn], Son;
/*
nowans 表示以u为根节点的子树,以u为LCA的方案数
tmpans 表示以u为根节点的子树,以其子树里的点为LCA的方案数
num[u] 表示以u为根节点的子树,以其子树里的点为LCA的方案数
tmp[u] 表示以u为某个区间的左右端点,这个区间的长度
*/
inline void Query (register int u) {
register int tmp1 = tmp[rank[u] - 1], tmp2 = tmp[rank[u] + 1]; // 得到左边区间的长度,右边区间的长度
register int len = tmp1 + tmp2 + 1; // 合并后区间的长度
tmp[rank[u] - tmp1] = tmp[rank[u] + tmp2] = len; // 将合并后的区间的左右端点的tmp更新
tmpans += len * (len - 1) / 2 - tmp1 * (tmp1 - 1) / 2 - tmp2 * (tmp2 - 1) / 2; // 注意减掉之前统计过的贡献
}
inline void Clear (register int u, register int fa) { // 清空操作
tmp[rank[u]] = 0;
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa) continue;
Clear (v, u);
}
}
inline void DFS3 (register int u, register int fa) { // 暴力统计轻儿子答案
Query (u);
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa || v == Son) continue;
DFS3 (v, u);
}
}
inline void DFS2 (register int u, register int fa, register bool opt) {
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
if (v == fa || v == son[u]) continue;
DFS2 (v, u, 0);
}
if (son[u]) DFS2 (son[u], u, 1), Son = son[u];
DFS3 (u, fa), Son = 0;
nowans = num[u] = tmpans;
for (register int i = head[u]; i; i = e[i].next) {
register int v = e[i].to;
nowans -= num[v]; // 减去子树的num,只要以u为LCA的答案
}
ans += nowans * val[u]; // 统计答案
if (opt == 0) Clear (u, fa), tmpans = 0; // 轻儿子清空
}
signed main () {
n = read();
for (register int v = 2; v <= n; v ++) {
register int u = read();
Add (u, v), Add (v, u);
}
for (register int i = 1; i <= n; i ++) a[i] = read(), rank[a[i]] = i;
for (register int i = 1; i <= n; i ++) val[i] = read(), sumval += val[i];
DFS1 (1, 0);
DFS2 (1, 0, 1);
printf ("%lld\n", ans + sumval);
return 0;
}