「学习笔记」dsu on tree
dsu on tree
简介
可以用来处理一些树上的问题,一般有2个特征:
- 询问子树上的信息。
- 没有修改。
思路
考虑暴力做法,对于每个子树,遍历一遍求答案,时间复杂度是 \(O(n^2)\) 的。
但是其实在求当前子树的答案时,它的信息父亲也是要用的,但是我们不能全继承上去,因为在遍历它的儿子求答案时,是相互独立的,需要清空。所以我们考虑怎么样能使优化最大化,显然是继承重儿子。
重儿子的预处理就是轻重链剖分。
因此,先遍历所有轻儿子,统计完答案后清空,再遍历重儿子,统计答案,继承给父亲。
时间复杂度
类似于树剖,因为重儿子继承了,所以是 \(O(n\log n)\)。
例题
1. CF600E Lomsat gelral
题意:
给出一个树,求出每个节点的子树中出现次数最多的颜色的编号和。
板子题,开个数组记录一下每个颜色的个数,暴力统计即可。
#include <iostream>
#include <cstdio>
#include <vector>
#define ll long long
using namespace std;
const int N = 1e5 + 5;
int n, c[N];
vector <int> g[N];
int siz[N], son[N];
void dfs1(int u, int fa)
{
siz[u] = 1;
for(int i = 0; i < g[u].size(); i++)
{
int v = g[u][i];
if(v == fa) continue;
dfs1(v, u);
siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
}
}
int Son, mx, cnt[N];
ll sum, ans[N];
void add(int u, int fa, int val)
{
cnt[c[u]] += val;
if(cnt[c[u]] > mx) mx = cnt[c[u]], sum = c[u];
else if(cnt[c[u]] == mx) sum += c[u];
for(int i = 0; i < g[u].size(); i++)
{
int v = g[u][i];
if(v == fa || v == Son) continue;
add(v, u, val);
}
}
void dfs2(int u, int fa, int op)
{
for(int i = 0; i < g[u].size(); i++)
{
int v = g[u][i];
if(v == fa || v == son[u]) continue;
dfs2(v, u, 0);
}
if(son[u]) dfs2(son[u], u, 1), Son = son[u];
add(u, fa, 1), Son = 0;
ans[u] = sum;
if(!op) add(u, fa, -1), mx = 0, sum = 0;
}
int rd()
{
int x = 0;
char c = getchar();
while(!isdigit(c)) c = getchar();
while(isdigit(c)) x = x * 10 + c - '0', c = getchar();
return x;
}
int main()
{
n = rd();
for(int i = 1; i <= n; i++)
c[i] = rd();
for(int i = 1; i < n; i++)
{
int u = rd(), v = rd();
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(1, 0);
dfs2(1, 0, 0);
for(int i = 1; i <= n; i++)
printf("%lld ", ans[i]);
puts("");
return 0;
}
2. CF375D Tree and Queries
题意:
给定一颗树,多次询问一颗子树中出现次数 ≥k 的颜色有多少种。
与上一道题类似,但要离线。
先把询问处理一下,把每个子树的询问挂到根上,统计答案时一起算。
#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
const int N = 1e5 + 5;
template <typename T>
void read(T &x)
{
x = 0;
char c = getchar();
while(!isdigit(c)) c = getchar();
while(isdigit(c)) x = x * 10 + c - '0', c = getchar();
return;
}
template <typename T>
void write(T x)
{
if(x > 9) write(x / 10);
putchar(x % 10 + '0');
}
int n, m, col[N];
vector <int> g[N];
struct query
{
int k, id;
};
vector <query> q[N];
int siz[N], son[N], f[N];
void dfs1(int u, int fa)
{
f[u] = fa;
siz[u] = 1;
for(int i = 0; i < g[u].size(); i++)
{
int v = g[u][i];
if(v == fa) continue;
dfs1(v, u);
siz[u] += siz[v];
if(siz[v] > siz[son[u]]) son[u] = v;
}
}
int cnt[N], sum, ans[N], d[N];
void add(int u, int val, int Son)
{
cnt[col[u]] += val;
if(val == 1) d[cnt[col[u]]]++;
else d[cnt[col[u]] + 1]--;
for(int i = 0; i < g[u].size(); i++)
{
int v = g[u][i];
if(v == f[u] || v == Son) continue;
add(v, val, Son);
}
}
void dfs2(int u, int op)
{
for(int i = 0; i < g[u].size(); i++)
{
int v = g[u][i];
if(v == f[u] || v == son[u]) continue;
dfs2(v, 0);
}
if(son[u]) dfs2(son[u], 1);
add(u, 1, son[u]);
for(int i = 0; i < q[u].size(); i++)
ans[q[u][i].id] = d[q[u][i].k];
if(!op) add(u, -1, 0), sum = 0;
}
int main()
{
read(n), read(m);
for(int i = 1; i <= n; i++)
read(col[i]);
for(int i = 1, u, v; i < n; i++)
{
read(u), read(v);
g[u].push_back(v);
g[v].push_back(u);
}
for(int i = 1, u; i <= m; i++)
{
query t;
read(u), read(t.k);
t.id = i;
q[u].push_back(t);
}
dfs1(1, 0);
dfs2(1, 0);
for(int i = 1; i <= m; i++)
write(ans[i]), puts("");
return 0;
}
参考资料
$$A\ drop\ of\ tear\ blurs\ memories\ of\ the\ past.$$