「解题报告」CF983E NN country
水点简单数据结构题!
考虑从两个点开始往上跳,每次肯定尽可能跳到最浅的点。两个点跳到再跳一步就能到达 lca 的位置的时候,此时再看看有没有路径连接这两个点,如果有那么一步就可以跳到,否则就要跳到 lca 再跳一步,两步跳到。跳的过程显然可以用倍增处理。
然后我们考虑处理出每个点能跳到的最浅的点。假如现在处理 \(u\) 点,有一条路径 \(x-y\) 满足 \(x\) 在 \(u\) 子树内,那么 \(u\) 能通过这条路径跳到的最浅的点为 \(\mathrm{lca}(u, y)\)。那么我们相当于要求深度最小的 lca,我们只需要找出 dfn 序最小的与 dfn 序最大的两个点即可。然后做完了。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 200005;
int n, m, q;
vector<int> e[MAXN];
int dfn[MAXN], ed[MAXN], idf[MAXN], fa[MAXN][22], dep[MAXN], dcnt;
void dfs(int u, int pre) {
dfn[u] = ++dcnt, idf[dcnt] = u, fa[u][0] = pre, dep[u] = dep[pre] + 1;
for (int i = 1; i <= 20; i++)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int v : e[u]) {
dfs(v, u);
}
ed[u] = dcnt;
}
int lca(int u, int v) {
if (dep[u] < dep[v]) swap(u, v);
for (int i = 20; i >= 0; i--) if (dep[fa[u][i]] >= dep[v]) u = fa[u][i];
if (u == v) return u;
for (int i = 20; i >= 0; i--) if (fa[u][i] != fa[v][i]) u = fa[u][i], v = fa[v][i];
return fa[u][0];
}
struct SegmentTree {
struct Node {
int lc, rc;
int sum;
} t[MAXN * 48];
int tot;
void insert(int d, int &p, int l = 1, int r = n) {
if (!p) p = ++tot, t[p].sum++;
else t[++tot] = t[p], p = tot, t[p].sum++;
if (l == r) return;
int mid = (l + r) >> 1;
if (d <= mid) insert(d, t[p].lc, l, mid);
else insert(d, t[p].rc, mid + 1, r);
}
int query(int a, int b, int p, int l = 1, int r = n) {
if (!p) return 0;
if (a <= l && r <= b) return t[p].sum;
int mid = (l + r) >> 1;
if (b <= mid) return query(a, b, t[p].lc, l, mid);
if (a > mid) return query(a, b, t[p].rc, mid + 1, r);
return query(a, b, t[p].lc, l, mid) + query(a, b, t[p].rc, mid + 1, r);
}
} st;
int root[MAXN];
int mnd[MAXN], mxd[MAXN];
int x[MAXN], y[MAXN];
vector<int> t[MAXN];
void dfs2(int u, int pre) {
for (int v : e[u]) {
dfs2(v, u);
mxd[u] = max(mxd[u], mxd[v]);
mnd[u] = min(mnd[u], mnd[v]);
}
}
int f[MAXN][22];
int main() {
scanf("%d", &n);
for (int i = 2; i <= n; i++) {
int p; scanf("%d", &p);
e[p].push_back(i);
}
for (int i = 1; i <= n; i++) {
mnd[i] = n + 1, mxd[i] = 0;
}
dfs(1, 0);
scanf("%d", &m);
for (int i = 1; i <= m; i++) {
int a, b; scanf("%d%d", &a, &b);
x[i] = a, y[i] = b;
t[dfn[a]].push_back(dfn[b]), t[dfn[b]].push_back(dfn[a]);
mxd[a] = max(mxd[a], dfn[b]), mnd[a] = min(mnd[a], dfn[b]);
mxd[b] = max(mxd[b], dfn[a]), mnd[b] = min(mnd[b], dfn[a]);
}
dfs2(1, 0);
for (int i = 1; i <= n; i++) {
root[i] = root[i - 1];
for (int j : t[i]) {
st.insert(j, root[i]);
}
}
for (int u = 1; u <= n; u++) {
f[u][0] = u;
if (mxd[u] != 0) {
int v = idf[mxd[u]];
int l = lca(u, v);
if (dep[l] < dep[f[u][0]]) f[u][0] = l;
}
if (mnd[u] != n + 1) {
int v = idf[mnd[u]];
int l = lca(u, v);
if (dep[l] < dep[f[u][0]]) f[u][0] = l;
}
}
for (int j = 1; j <= 20; j++) {
for (int i = 1; i <= n; i++) {
f[i][j] = f[f[i][j - 1]][j - 1];
}
}
scanf("%d", &q);
for (int i = 1; i <= q; i++) {
int u, v; scanf("%d%d", &u, &v);
int l = lca(u, v);
int ans = 0;
if (l == u || l == v) {
if (l == u) swap(u, v);
for (int i = 20; i >= 0; i--) if (dep[f[u][i]] > dep[v]) u = f[u][i], ans += 1 << i;
if (dep[f[u][0]] <= dep[v]) printf("%d\n", ans + 1);
else printf("-1\n");
} else {
for (int i = 20; i >= 0; i--) if (dep[f[u][i]] > dep[l]) u = f[u][i], ans += 1 << i;
for (int i = 20; i >= 0; i--) if (dep[f[v][i]] > dep[l]) v = f[v][i], ans += 1 << i;
if (dep[f[u][0]] > dep[l]) printf("-1\n");
else if (dep[f[v][0]] > dep[l]) printf("-1\n");
else {
int c = st.query(dfn[u], ed[u], root[ed[v]]) - st.query(dfn[u], ed[u], root[dfn[v] - 1]);
if (c) ans += 1;
else ans += 2;
printf("%d\n", ans);
}
}
}
return 0;
}