P6374 「StOI-1」树上询问
P6374 「StOI-1」树上询问
Description
给定一颗 \(n\) 个节点的树,有 \(q\) 次询问。
每次询问给一个参数三元组 \((a, b, c)\) ,求有多少个 \(i\) 满足这棵树在以 \(i\) 为根的情况下 \(a\) 和 \(b\) 的 \(\text{LCA}\) 是 \(c\)。
其中,\(1 \leq n \leq 5 \times 10^{5},\ 1 \leq q \leq 2 \times 10^{5}\)
Solution
看到数据范围,可以想到大概是 \(O(q \log n)\) 的复杂度,而且要求 \(\text{LCA}\),大概是需要我们先 \(O(n)\) 处理一下这棵树,然后寻找一些 \(\text{LCA}\) 性质关系来解决这个问题。
其实这个题目给的三个样例挺好的,我们完全可以在手玩样例的时候隐隐约约感受到正解,只不过需要再加上一些细节处理。
我们拿这一颗树为例子吧。
如果三元组是 \((15, 19, 16)\) 也就是说 \(z\) 在 \(x,\ y\) 的路径上面并且 \(z\) 不是 \(x, \ y\) 的最近公共祖先,那么我们把 \(16\) 给拎起来,容易发现,答案的个数就是除了 \(16\) 到 \(19\) 这个路径上的 \(16\) 的子树,其余的子树的大小之和。
如果三元组是 \((1, 8, 3)\) 也就是 \(z\) 不在 \(x, \ y\) 的路径之上,那么无论我们让哪个节点作为根节点都不会成立,所以答案就是 0
如果三元组是 \((9, 12, 1)\) ,也就是说 \(z\) 正好就是 \(x, \ y\) 的最近公共祖先,那么答案就是整棵树除了公共祖先分别到 \(x, \ y\) 的子树的其余子树的大小之和。
再处理一些细节问题,比如说 \((20, 20, 20)\) 和 \((1, 12, 1)\) 这样的三元组。
Code
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;
int n, q, cnt, head[N], dep[N], fa[N][20], siz[N], lg[N];
struct Edge {
int to, nxt;
}e[N << 1];
inline int read() {
int x = 0, f = 1;
char c = getchar();
while (!isdigit(c)) {
if (c == '-') f = -1;
c = getchar();
}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
return x * f;
}
inline void add(int x, int y) {
e[++cnt].to = y;
e[cnt].nxt = head[x];
head[x] = cnt;
}
inline void dfs(int x, int f) {
dep[x] = dep[f] + 1;
siz[x] = 1;
fa[x][0] = f;
for (int i = 1; i <= lg[dep[x]]; ++i) fa[x][i] = fa[fa[x][i - 1]][i - 1];
for (int i = head[x]; i ; i = e[i].nxt) {
int to = e[i].to;
if (to == f) continue;
dfs(to, x);
siz[x] += siz[to];
}
}
inline int lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
while (dep[x] != dep[y]) x = fa[x][lg[dep[x] - dep[y]]];
if (x == y) return x;
for (int i = lg[dep[x]]; i >= 0; --i) if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
inline int find_son(int x, int y) {
if (x == y) return 0;
while (dep[x] != dep[y] + 1) x = fa[x][lg[dep[x] - dep[y] - 1]];
return x;
}
int main() {
n = read(), q = read();
int x, y;
for (int i = 1; i <= n - 1; ++i) {
x = read(), y = read();
add(x, y);
add(y, x);
}
for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
dfs(1, 0);
int z;
while (q--) {
x = read(), y = read(), z = read();
if (x == y && z == x) {
printf("%d\n", n);
continue;
}
int l = lca(x, y);
if (z == l) {
int fx = find_son(x, z);
int fy = find_son(y, z); // 处理相同的特殊情况
printf("%d\n", siz[1] - siz[fx] - siz[fy]);
} else if ((lca(x, z) == z || lca(y, z) == z) && dep[l] <= dep[z]) {
if (lca(z, y) == l) {
int fx = find_son(x, z);
printf("%d\n", siz[z] - siz[fx]);
}
else {
int fy = find_son(y, z);
printf("%d\n", siz[z] - siz[fy]);
}
} else printf("0\n");
}
system("pause");
return 0;
}
Summary
其中判断树上三个点 \(x, y, z\) ,\(z\) 是否位于 \(x, y\) 两个点间的路径之上。我们可以通过 \((\text{lca}(x, z) == z \ |\ | \ \text{lca}(y, z) == z)\ \&\& \ dep[z] >= dep[\text{lca}(x,y)]\) 来判断。
或者也可以用 \(dis\) 数组代表根到每个点的距离,对于 \(x, y, z\) 中 \(z\) 是否位于 \(x, \ y\) 路径之上可以用 \(dis[x] + dis[y] - 2 * dis[lca(a,b)] == dis[a] + dis[c] - 2 * dis[lca(a, c)] + dis[b] + dis[c] - 2 * dis[lca(b, c)]\) 来判断。