【题解】P6071 『MdOI R1』Treequery
思路
清真树论。
从祖先—后代关系入手,分类讨论。
用 ST 表(每格存一段编号区间内所有结点的 LCA)求出 \(u = \operatorname{lca}(l, l+1, \dots, r)\),即编号在 \([l, r]\) 内所有结点的最近公共祖先:
- \(u, p\) 无祖先后代关系:答案为 \(dis(u, p)\)(两点间路径的边权和)。
- \(p\) 是 \(u\) 的祖先(含 \(p = u\)):答案为 \(dis(u, p)\)。
- \(u\) 是 \(p\) 的真祖先,继续分类讨论:
  - \(p\) 的子树内有编号在 \([l, r]\) 中的点:这些路径没有公共边,答案为 \(0\)。
  - 否则 \([l, r]\) 中的点都在 \(p\) 的子树外:向上倍增找到 \(p\) 最低的、子树中包含 \([l, r]\) 中结点的祖先 \(v\),答案为 \(dis(p, v)\)。
最后一步需要快速判断某棵子树内是否含有编号在 \([l, r]\) 中的点:以编号为版本建主席树(可持久化线段树),维护这些点的 dfs 序;一棵子树对应 dfs 序上的一段连续区间,做一次区间计数即可。
时间复杂度 \(O((n + q) \log^2 n)\):建 ST 表时每格要做一次 \(O(\log n)\) 的 LCA,每次询问的倍增中每一步要做一次 \(O(\log n)\) 的主席树查询。
代码
#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;
// ---- capacities ----
constexpr int maxn = 200000 + 5;   // vertex capacity (vertices are 1-based)
constexpr int lg_sz = 20;          // columns of the binary-lifting / sparse tables
constexpr int t_sz = maxn * 40;    // node pool for the persistent segment tree

// ---- input sizes and a shared counter ----
int n;    // number of vertices
int q;    // number of queries
int cnt;  // dfs2 uses it for dfs numbering; main resets it and update() then allocates tree nodes with it

// ---- heavy-light decomposition data ----
int fa[maxn];   // parent (0 for the root)
int son[maxn];  // heavy child (0 for a leaf)
int dfn[maxn];  // dfs visiting index; a subtree occupies [dfn[v], dfn[v] + sz[v] - 1]
int top[maxn];  // head of the heavy chain containing the vertex
int dep[maxn];  // depth, the root has depth 1
int sz[maxn];   // subtree size
int len[maxn];  // weighted distance from the root
int lg[maxn];   // lg[i] = floor(log2(i))
int rt[maxn];   // rt[i] = persistent segment tree root after inserting dfn[1..i]

// ---- jump tables ----
int f[maxn][lg_sz];   // f[v][j] = 2^j-th ancestor of v (0 above the root)
int lc[maxn][lg_sz];  // lc[i][j] = LCA of vertices i .. i + 2^j - 1

// ---- persistent segment tree pool ----
int ls[t_sz];   // left child
int rs[t_sz];   // right child
int sum[t_sz];  // number of inserted positions in the node's range

// ---- adjacency lists ----
std::vector<int> g[maxn];  // g[u][i] = i-th neighbour of u
std::vector<int> d[maxn];  // d[u][i] = weight of the edge (u, g[u][i])
void dfs1(int u, int pre)
{
fa[u] = pre;
sz[u] = 1;
dep[u] = dep[pre] + 1;
for (int i = 0; i < g[u].size(); i++)
{
int v = g[u][i], w = d[u][i];
if (v == pre) continue;
f[v][0] = u;
len[v] = len[u] + w;
dfs1(v, u);
sz[u] += sz[v];
if (sz[v] > sz[son[u]]) son[u] = v;
}
}
// Second heavy-light pass: assigns dfs indices (dfn, continuing from cnt) and
// chain heads (top) to the subtree of u, whose chain head is t.
// The heavy child is visited first, so every heavy chain and every subtree gets
// a contiguous dfn range. Light children follow in adjacency order.
// This uses an explicit stack instead of recursion, because a long heavy chain
// would otherwise recurse once per vertex. The visiting order is identical to
// the recursive form:
// - light children are pushed in reverse adjacency order, so they pop in order;
// - the heavy child is pushed last, so its whole subtree is finished first.
void dfs2(int u, int t)
{
    top[u] = t;
    std::vector<int> stk(1, u);
    while (!stk.empty())
    {
        const int x = stk.back();
        stk.pop_back();
        dfn[x] = ++cnt;
        for (auto it = g[x].rbegin(); it != g[x].rend(); ++it)
        {
            const int v = *it;
            if (v == fa[x] || v == son[x]) continue;
            top[v] = v;  // a light child starts a new chain
            stk.push_back(v);
        }
        if (son[x])
        {
            top[son[x]] = top[x];  // the heavy child continues x's chain
            stk.push_back(son[x]);
        }
    }
}
// Lowest common ancestor via the heavy-light decomposition (top / fa / dep).
// While u and v are on different chains, the endpoint whose chain head is
// deeper jumps above its chain. Once both are on the same chain, the shallower
// one is the answer.
int lca(int u, int v)
{
    while (top[u] != top[v])
    {
        if (dep[top[u]] >= dep[top[v]])
            u = fa[top[u]];
        else
            v = fa[top[v]];
    }
    if (dep[u] < dep[v]) return u;
    return v;
}
// LCA of all vertices numbered l..r (requires l <= r). Two overlapping
// power-of-two blocks from the sparse table lc cover the range; their LCAs are
// then combined.
int get_lca(int l, int r)
{
    const int k = lg[r - l + 1];
    const int head_block = lc[l][k];
    const int tail_block = lc[r - (1 << k) + 1][k];
    return lca(head_block, tail_block);
}
// True when u lies in the subtree of v (u == v counts).
// v's subtree occupies the contiguous dfn range [dfn[v], dfn[v] + sz[v] - 1].
bool in(int u, int v)
{
    const int first = dfn[v];
    const int last = first + sz[v] - 1;
    return first <= dfn[u] && dfn[u] <= last;
}
// Weighted length of the tree path between u and v.
// This is computed as (len[u] - len[a]) + (len[v] - len[a]) with a = lca(u, v).
// Both terms are non-negative and their sum is the path length itself, so no
// intermediate value exceeds the result. The naive len[u] + len[v] - 2 * len[a]
// can overflow int when root distances approach INT_MAX.
int dis(int u, int v)
{
    const int a = lca(u, v);
    return (len[u] - len[a]) + (len[v] - len[a]);
}
// Persistent point insertion: returns the root of a new version equal to version
// `pre` plus one occurrence at position p. Node range is [l, r].
// Only the O(log n) nodes on the path to p are copied; fresh node ids come from
// the global counter cnt.
int update(int pre, int l, int r, int p)
{
    const int node = ++cnt;
    ls[node] = ls[pre];
    rs[node] = rs[pre];
    sum[node] = sum[pre] + 1;
    if (l != r)
    {
        const int mid = (l + r) >> 1;
        if (p > mid)
            rs[node] = update(rs[pre], mid + 1, r, p);
        else
            ls[node] = update(ls[pre], l, mid, p);
    }
    return node;
}
int query(int k, int pre, int l, int r, int ql, int qr)
{
if ((l >= ql) && (r <= qr)) return sum[k] - sum[pre];
int mid = (l + r) >> 1, sum = 0;
if (ql <= mid) sum += query(ls[k], ls[pre], l, mid, ql, qr);
if (qr > mid) sum += query(rs[k], rs[pre], mid + 1, r, ql, qr);
return sum;
}
// Builds the lookup structures once the HLD passes (dfs1, dfs2) have run:
// - lg: floor-log table;
// - f: binary-lifting table;
// - lc: sparse table, lc[i][j] = LCA of vertices i .. i + 2^j - 1;
// - rt: persistent segment tree roots, rt[i] holds dfn[1..i].
static void build_tables()
{
    cnt = 0;  // cnt numbered dfs order in dfs2; from here on it allocates tree nodes
    for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
    for (int j = 1; (1 << j) <= n; j++)
        for (int i = 1; i <= n; i++)
            f[i][j] = f[f[i][j - 1]][j - 1];
    for (int i = 1; i <= n; i++) lc[i][0] = i;
    for (int j = 1; (1 << j) <= n; j++)
        for (int i = 1; i + (1 << j) - 1 <= n; i++)
            lc[i][j] = lca(lc[i][j - 1], lc[i + (1 << (j - 1))][j - 1]);
    for (int i = 1; i <= n; i++) rt[i] = update(rt[i - 1], 1, n, dfn[i]);
}

// Number of vertices numbered l..r that lie in the subtree of v.
static int count_in_subtree(int v, int l, int r)
{
    return query(rt[r], rt[l - 1], 1, n, dfn[v], dfn[v] + sz[v] - 1);
}

// Answers one decoded query: the total weight of the edges that lie on every
// path p -> i for i in [l, r].
static int answer_query(int p, int l, int r)
{
    const int u = get_lca(l, r);
    // p is an ancestor of u (or u itself), or unrelated to u: all paths pass
    // through u and share exactly the path p -> u.
    if (in(u, p) || !in(p, u)) return dis(u, p);
    // u is a proper ancestor of p, so some target lies outside p's subtree.
    // If another target lies inside it, the paths leave p in different directions.
    if (count_in_subtree(p, l, r)) return 0;
    // Lift cur to the highest ancestor of p whose subtree holds no target.
    // Its parent is where the paths to the targets separate.
    int cur = p;
    for (int i = lg[dep[cur]]; i >= 0; i--)
    {
        const int v = f[cur][i];
        if (v && !count_in_subtree(v, l, r)) cur = v;
    }
    return dis(p, f[cur][0]);
}

int main()
{
    if (scanf("%d%d", &n, &q) != 2) return 1;
    for (int i = 1, u, v, w; i <= n - 1; i++)
    {
        if (scanf("%d%d%d", &u, &v, &w) != 3) return 1;
        g[u].push_back(v), d[u].push_back(w);
        g[v].push_back(u), d[v].push_back(w);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    build_tables();

    int last_ans = 0;
    while (q--)
    {
        int p, l, r;
        if (scanf("%d%d%d", &p, &l, &r) != 3) return 1;
        // Queries are forced online: each value is XOR-ed with the previous answer.
        p ^= last_ans, l ^= last_ans, r ^= last_ans;
        last_ans = answer_query(p, l, r);
        printf("%d\n", last_ans);
    }
    return 0;
}