基础树上问题
五天没见过逆向题的 flag 了,甚是自闭,遂决定切换一下心情水点 acm 的小水题,顺带记录一下
LCA模板
发现博客里没有 LCA 模板,这不太好,放上板子方便 acv(
在线计算 LCA 一般使用倍增,也就是跳 1,2,4,8,16,321,2,4,8,16,32 …… 不过在这不是按从小到大跳,而是从大向小跳,即按……32,16,8,4,2,132,16,8,4,2,1来跳,如果大的跳不过去,再把它调小
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 1e7 + 10;
const int maxm = 1e4 + 10;
inline int read() {
int x = 0, k = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') k = -1;
for (; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
return x * k;
}
int n, m, s, x, y, a, b, ans;
int cnt, head[maxn];
int lg[maxn], dep[maxn], f[maxn][30];
struct node{
int to, nxt;
}e[maxn];
inline void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = head[u];
head[u] = cnt;
}
inline void dfs(int u, int fa) {
dep[u] = dep[fa] + 1; // dep[x] 为 x 节点的深度
f[u][0] = fa; // f[i][j] 表示 i 节点的 2^j 级祖先
for (int i = 1; (1 << i) <= dep[u]; i++) f[u][i] = f[f[u][i - 1]][i - 1]; // 意思是 f 的 2^i 祖先等于 f 的 2^(i-1) 祖先的 2^(i-1) 祖先
for (int i = head[u]; i; i = e[i].nxt)
if (e[i].to != fa) dfs(e[i].to, u);
}
inline int lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y); // 假设 x 是最深的节点
while (dep[x] > dep[y]) x = f[x][lg[dep[x] - dep[y]] - 1]; // 让两个节点一样深
if (x == y) return x;
for (int i = lg[dep[x]] - 1; i >= 0; i--)
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i]; // 如果不一样那么肯定没有到达 lca ,因为两个节点的 lca 向上的节点就是一样的了
return f[x][0];
}
int main() {
n = read(); m = read(); s = read();
for (int i = 1; i < n; i++) {
x = read(); y = read();
add(x, y); add(y, x);
}
dfs(s, 0);
for (int i = 1; i <= n; i++)
lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
for (int i = 1; i <= m; i++) {
a = read(); b = read();
ans = lca(a, b);
printf("%d\n", ans);
}
return 0;
}
acv版无注释精华良品:
int n, m, s, x, y, a, b, ans;
int cnt, head[maxn];
int lg[maxn], dep[maxn], f[maxn][30];
struct node{
int to, nxt;
}e[maxn];
inline void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = head[u];
head[u] = cnt;
}
inline void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
f[u][0] = fa;
for (int i = 1; (1 << i) <= dep[u]; i++) f[u][i] = f[f[u][i - 1]][i - 1];
for (int i = head[u]; i; i = e[i].nxt)
if (e[i].to != fa) dfs(e[i].to, u);
}
inline int lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
while (dep[x] > dep[y]) x = f[x][lg[dep[x] - dep[y]] - 1];
if (x == y) return x;
for (int i = lg[dep[x]] - 1; i >= 0; i--)
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
int main() {
n = read(); m = read(); s = read();
for (int i = 1; i < n; i++) {
x = read(); y = read();
add(x, y); add(y, x);
}
dfs(s, 0);
for (int i = 1; i <= n; i++)
lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
for (int i = 1; i <= m; i++) {
a = read(); b = read();
ans = lca(a, b);
printf("%d\n", ans);
}
return 0;
}
LCA的应用
1、紧急集合
题目链接:P4281 [AHOI2008]紧急集合 / 聚会
题目大意:
给定一棵 N 个节点的树以及 M 次询问,每次询问给出 x, y, z 三个节点,要求在树上找一个点 p 使得 c = dist(x,p)+dist(y,p)+dist(z,p) 取最小值,每一次询问输出满足条件的 p 和此时的最小的 c
通过瞪眼法可以得到:
1、三个点两两之间的 LCA 一定有两个点相同
2、如果只有 2 个点相同,那么聚集点就一定是剩下一个 LCA
3、如果 3 个点的 LCA 都相同,那么聚集点就是这个 LCA
综上所述,集结点应该在三个点的三个最近公共祖先中深度最深的那个点上
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 500010 * 2 + 10;
const int maxm = 500010;
int n, m, cnt, ans;
struct node{
int to, nxt;
}e[maxn];
int num1, num2, num, num3, x, y, z;
int head[maxm], dep[maxm], lg[maxm], f[maxm][30];
inline int read() {
int x = 0, k = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') k = -1;
for (; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
return x * k;
}
inline void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = head[u];
head[u] = cnt;
}
inline void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
f[u][0] = fa;
for (int i = 1; (1 << i) <= dep[u]; i++) f[u][i] = f[f[u][i - 1]][i - 1];
for (int i = head[u]; i; i = e[i].nxt)
if (e[i].to != fa) dfs(e[i].to, u);
}
inline int lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
while (dep[x] > dep[y]) x = f[x][lg[dep[x] - dep[y]] - 1];
if (x == y) return x;
for (int i = lg[dep[x]] - 1; i >= 0; i--)
if (f[x][i] != f[y][i]) {
x = f[x][i]; y = f[y][i];
}
return f[x][0];
}
int main() {
n = read(); m = read();
for (int i = 1; i <= n; i++)
lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
for (int i = 1; i < n; i++) {
register int a, b;
a = read(); b = read();
add(a, b); add(b, a);
}
dfs(1, 0);
for (int i = 1; i <= m; i++) {
x = read(); y = read(); z = read();
int r1 = lca(x, y), r2 = lca(x, z), r3 = lca(y, z);
if (dep[r1] >= dep[r2] && dep[r1] >= dep[r3]) {
num1 = x; num2 = y; num3 = z; num = r1;
} else if (dep[r2] >= dep[r1] && dep[r2] >= dep[r3]) {
num1 = x; num2 = z; num3 = y; num = r2;
} else if (dep[r3] >= dep[r1] && dep[r3] >= dep[r2]) {
num1 = y; num2 = z; num3 = x; num = r3;
}
int r = lca(num1, num3);
ans = dep[num1] + dep[num2] - 2 * dep[num] + dep[num3] + dep[num] - 2 * dep[r];
printf("%d %d\n", num, ans);
}
return 0;
}
2、仓鼠找sugar
题目链接:P3398 仓鼠找sugar
如果两个起点的距离 + 两个终点的距离 >= 两条路径的长度和
那么两条路径有一部分一定是重合的(或者说一定是存在公共点的)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxn = 100010;
const int maxm = 1e4 + 10;
int n, q, a, b, c, d, cnt;
struct node{
int to, nxt;
}e[maxn * 2];
int head[maxn], dep[maxn], lg[maxn], f[maxn][30];
inline int read() {
int x = 0, k = 1; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') k = -1;
for (; isdigit(ch); ch = getchar()) x = x * 10 + ch - '0';
return x * k;
}
inline void add(int u, int v) {
e[++cnt].to = v;
e[cnt].nxt = head[u];
head[u] = cnt;
}
inline void dfs(int u, int fa) {
dep[u] = dep[fa] + 1;
f[u][0] = fa;
for (int i = 1; (1 << i) <= dep[u]; i++)
f[u][i] = f[f[u][i - 1]][i - 1];
for (int i = head[u]; i; i = e[i].nxt)
if (e[i].to != fa) dfs(e[i].to, u);
}
inline int lca(int x, int y) {
if (dep[x] < dep[y]) swap(x, y);
while (dep[x] > dep[y]) x = f[x][lg[dep[x] - dep[y]] - 1];
if (x == y) return x;
for (int i = lg[dep[x]] - 1; i >= 0; i--)
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
int main() {
n = read(); q = read();
for (int i = 1; i < n; i++) {
register int u, v;
u = read(); v = read();
add(u, v); add(v, u);
}
dfs(1, 0);
for (int i = 1; i <= n; i++)
lg[i] = lg[i - 1] + (1 << lg[i - 1] == i);
for (int i = 1; i <= q; i++) {
a = read(); b = read(); c = read(); d = read();
register int ab = dep[a] + dep[b] - 2 * dep[lca(a, b)];
register int cd = dep[c] + dep[d] - 2 * dep[lca(c, d)];
register int ac = dep[a] + dep[c] - 2 * dep[lca(a, c)];
register int bd = dep[b] + dep[d] - 2 * dep[lca(b, d)];
if (ab + cd >= ac + bd) printf("Y\n");
else printf("N\n");
}
return 0;
}