【题解】P5384 [Cnoi2019] 雪松果树
看到一年前的 40pts 突然想填坑,就当顺便回忆一下怎么写题解了。
思路
线段树合并 / dsu on tree / 长链剖分 / vector + 二分 / dfs 序 + 树状数组(差分)。
线段树合并
复杂度 \(O(n \log n)\),可以卡掉。
考虑用线段树合并维护每个结点的子树内所有深度的出现次数。
dsu on tree
和线段树合并同理,也可以卡掉。
长链剖分
复杂度 \(O(n)\),勉强能过。
考虑求一个结点的 k-cousin 分成:
-
求结点的 k-father
-
求该结点的 k-son 个数
第一步可以在 dfs 的时候顺便开一个 \(O(n)\) 的栈维护。
第二步是经典的长链剖分问题,直接长剖 dp \(O(n)\) 维护就行。
vector + 二分
考虑对每个深度开一个 vector,把所有该深度的结点按 dfs 序塞进去。
查询的时候直接二分左右端点数长度,复杂度也是 \(O(n \log n)\).
dfs 序 + 树状数组(差分)
转化成 dfs 序上的问题等价于问:指定的区间内有多少个数等于给定的值 \(k\)?
直接 BIT 随便维护一下,复杂度也是假的。
其实不需要 BIT,考虑在 dfs 的同时维护每个深度的出现次数 \(cnt_d\),查询某棵子树内深度为 \(d\) 的结点个数只需要在用遍历这棵子树前后的 \(cnt_d\) 差分。
时间复杂度也是 \(O(n)\),应该比长剖做法好写。
好水,但我之前为啥卡常卡了十发。
代码
写的无脑长剖做法。
#include <cstdio>
#include <vector>
using namespace std;
#define il inline
const int maxn = 1e6 + 5;
const int maxq = 1e6 + 5;
int n, q;
int f[maxn];
int top, stk[maxn];
int son[maxn], dep[maxn];
int ans[maxq];
int *cur = f, *dp[maxn];
vector<int> g[maxn], qk[maxn], qid[maxn];
il int read()
{
int res = 0;
char ch = getchar();
while ((ch < '0') || (ch > '9')) ch = getchar();
while ((ch >= '0') && (ch <= '9')) res = res * 10 + ch - '0', ch = getchar();
return res;
}
il void dfs1(int u)
{
stk[++top] = u;
for (int i = 0, v; i < qk[u].size(); i++)
{
if (top > qk[u][i])
{
v = stk[top - qk[u][i]];
qk[v].push_back(qk[u][i]), qid[v].push_back(qid[u][i]);
}
}
qk[u].clear(), qid[u].clear();
for (int v : g[u])
{
dfs1(v);
if (dep[v] > dep[son[u]]) son[u] = v;
}
dep[u] = dep[son[u]] + 1, top--;
}
il void dfs2(int u)
{
dp[u][0] = 1;
if (son[u]) dp[son[u]] = dp[u] + 1, dfs2(son[u]);
for (int v : g[u])
{
if (v == son[u]) continue;
dp[v] = cur, cur += dep[v], dfs2(v);
for (int i = 1; i <= dep[v]; i++) dp[u][i] += dp[v][i - 1];
}
for (int i = 0; i < qk[u].size(); i++) ans[qid[u][i]] = dp[u][qk[u][i]] - 1;
}
int main()
{
n = read(), q = read();
for (int i = 2; i <= n; i++) g[read()].push_back(i);
for (int i = 1, u; i <= q; i++)
{
u = read();
qk[u].push_back(read()), qid[u].push_back(i);
}
dfs1(1);
dp[1] = cur, cur += dep[1], dfs2(1);
for (int i = 1; i <= q; i++) printf("%d ", ans[i]); puts("");
return 0;
}