Loading

【题解】P6071 『MdOI R1』Treequery

思路

清真树论。

树论地考虑祖先后代关系,分讨一下。

用 ST 表预处理区间 LCA,记 \(u = lca(l, l + 1, \dots, r)\),然后分情况讨论:

  1. \(u, p\) 无祖先后代关系,答案为 \(dis(u, p)\)

  2. \(p\) 是 \(u\) 的祖先,答案为 \(dis(u, p)\)

  3. \(u\) 是 \(p\) 的祖先,继续分类讨论:

    • \(p\) 的子树内有 \([l, r]\) 中的点,答案为 \(0\)

    • 否则 \([l, r]\) 中的点都在 \(p\) 的子树外,向上倍增找到 \(p\) 最低且子树中包含 \([l, r]\) 中结点的祖先 \(a\),答案为 \(dis(p, a)\)。

最后一步可以主席树维护 \([l, r]\) 内的 dfs 序。

时间复杂度 \(O((n + q) \log^2 n)\)

代码

#include <cstdio>
#include <vector>
#include <iostream>
using namespace std;

// ---- Capacity limits -------------------------------------------------------
constexpr int maxn = 200005;      // upper bound on vertex indices (arrays are 1-based)
constexpr int lg_sz = 20;         // number of levels in the doubling / sparse tables
constexpr int t_sz = maxn * 40;   // node pool size of the persistent segment tree

// ---- Input sizes and shared counter ----------------------------------------
int n;    // number of vertices
int q;    // number of queries
int cnt;  // dfs clock inside dfs2, later reset and reused as segment-tree node allocator

// ---- Heavy-light decomposition data (filled by dfs1 / dfs2) ----------------
int fa[maxn];   // parent of each vertex (0 for the root)
int son[maxn];  // child with the largest subtree, 0 if none
int dfn[maxn];  // dfs-order index assigned by dfs2
int top[maxn];  // head vertex of the chain the vertex belongs to
int dep[maxn];  // depth, dep[pre] + 1
int sz[maxn];   // subtree size
int len[maxn];  // sum of edge weights on the path from the dfs root

// ---- Lookup tables built in main -------------------------------------------
int lg[maxn];         // lg[i] = lg[i >> 1] + 1, i.e. floor(log2(i)) for i >= 1
int rt[maxn];         // rt[i]: segment-tree root after inserting dfn[1..i]
int f[maxn][lg_sz];   // f[v][j]: 2^j-th ancestor of v (0 when it does not exist)
int lc[maxn][lg_sz];  // lc[i][j]: LCA of vertices i .. i + 2^j - 1

// ---- Persistent segment tree node pool -------------------------------------
int ls[t_sz];   // left child id
int rs[t_sz];   // right child id
int sum[t_sz];  // number of inserted positions inside the node's range

// ---- Adjacency lists --------------------------------------------------------
vector<int> g[maxn];  // g[u][i]: i-th neighbour of u
vector<int> d[maxn];  // d[u][i]: weight of the edge to g[u][i]

void dfs1(int u, int pre)
{
    fa[u] = pre;
    sz[u] = 1;
    dep[u] = dep[pre] + 1;
    for (int i = 0; i < g[u].size(); i++)
    {
        int v = g[u][i], w = d[u][i];
        if (v == pre) continue;
        f[v][0] = u;
        len[v] = len[u] + w;
        dfs1(v, u);
        sz[u] += sz[v];
        if (sz[v] > sz[son[u]]) son[u] = v;
    }
}

// Second DFS pass: assigns dfs-order indices and chain heads. The heavy child
// is visited first and inherits the chain head t; every other child starts a
// new chain headed by itself.
void dfs2(int u, int t)
{
    dfn[u] = ++cnt;
    top[u] = t;

    const int heavy = son[u];
    if (heavy != 0) dfs2(heavy, t);

    for (int v : g[u])
    {
        if (v == fa[u] || v == heavy) continue;
        dfs2(v, v);
    }
}

// Lowest common ancestor of u and v via chain jumping: while the two vertices
// sit on different chains, move the one whose chain head is deeper (u on ties)
// to the parent of that head; once on the same chain the shallower one wins.
int lca(int u, int v)
{
    while (top[u] != top[v])
    {
        if (dep[top[u]] >= dep[top[v]])
            u = fa[top[u]];
        else
            v = fa[top[v]];
    }
    if (dep[u] < dep[v]) return u;
    return v;
}

// LCA of all vertices with indices in [l, r], answered from the sparse table
// lc by combining two overlapping power-of-two windows.
int get_lca(int l, int r)
{
    const int level = lg[r - l + 1];
    const int window = 1 << level;
    const int left_part = lc[l][level];
    const int right_part = lc[r - window + 1][level];
    return lca(left_part, right_part);
}

// True when dfn[u] lies inside the dfs-order interval occupied by the subtree
// of v, i.e. u is v itself or one of its descendants.
bool in(int u, int v)
{
    const int first = dfn[v];
    const int last = first + sz[v] - 1;
    if (dfn[u] < first) return false;
    return dfn[u] <= last;
}

// Weighted distance between u and v, computed from root-path lengths.
// The two legs (u -> lca and v -> lca) are subtracted first and added last:
// with non-negative edge weights each leg is at most the final distance, so
// no intermediate value overflows int as long as the result itself fits
// (the naive len[u] + len[v] could exceed INT_MAX before the subtraction).
int dis(int u, int v)
{
    const int anc = lca(u, v);
    const int up_from_u = len[u] - len[anc];
    const int up_from_v = len[v] - len[anc];
    return up_from_u + up_from_v;
}

// Persistent insertion: creates a new version of the node `pre` (covering
// [l, r]) with position p counted once more, sharing every untouched child
// with the previous version. Returns the id of the new node.
int update(int pre, int l, int r, int p)
{
    const int node = ++cnt;
    ls[node] = ls[pre];
    rs[node] = rs[pre];
    sum[node] = sum[pre] + 1;

    if (l < r)
    {
        const int mid = (l + r) >> 1;
        if (p > mid)
            rs[node] = update(rs[pre], mid + 1, r, p);
        else
            ls[node] = update(ls[pre], l, mid, p);
    }
    return node;
}

// Counts the positions inside [ql, qr] that are present in version k but not
// in version pre. Both nodes cover the same range [l, r].
int query(int k, int pre, int l, int r, int ql, int qr)
{
    // Node range fully covered by the query: the count difference is the answer.
    if (ql <= l && r <= qr) return sum[k] - sum[pre];

    const int mid = (l + r) >> 1;
    int total = 0;
    if (ql <= mid) total += query(ls[k], ls[pre], l, mid, ql, qr);
    if (mid < qr) total += query(rs[k], rs[pre], mid + 1, r, ql, qr);
    return total;
}

int main()
{
    int last_ans = 0;
    scanf("%d%d", &n, &q);
    for (int i = 1, u, v, w; i <= n - 1; i++)
    {
        scanf("%d%d%d", &u, &v, &w);
        g[u].push_back(v), d[u].push_back(w);
        g[v].push_back(u), d[v].push_back(w);
    }
    dfs1(1, 0);
    dfs2(1, 1);
    cnt = 0;
    for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
    for (int j = 1; (1 << j) <= n; j++)
        for (int i = 1; i <= n; i++)
            f[i][j] = f[f[i][j - 1]][j - 1];
    for (int i = 1; i <= n; i++) lc[i][0] = i;
    for (int j = 1; (1 << j) <= n; j++)
        for (int i = 1; i + (1 << j) - 1 <= n; i++)
            lc[i][j] = lca(lc[i][j - 1], lc[i + (1 << (j - 1))][j - 1]);
    for (int i = 1; i <= n; i++) rt[i] = update(rt[i - 1], 1, n, dfn[i]);
    while (q--)
    {
        int p, l, r;
        scanf("%d%d%d", &p, &l, &r);
        p ^= last_ans, l ^= last_ans, r ^= last_ans;
        // printf("get query: %d %d %d\n", p, l, r);
        int u = get_lca(l, r);
        // printf("get lca(%d, %d) = %d\n", 3, 5, get_lca(3, 5));
        if (in(u, p) || (!in(p, u))) last_ans = dis(u, p);
        else if (query(rt[r], rt[l - 1], 1, n, dfn[p], dfn[p] + sz[p] - 1)) last_ans = 0;
        else
        {
            // puts("in");
            int cur = p;
            for (int i = lg[dep[cur]]; i >= 0; i--)
            {
                int v = f[cur][i];
                if (!v) continue;
                if (!query(rt[r], rt[l - 1], 1, n, dfn[v], dfn[v] + sz[v] - 1)) cur = v;
                // puts("query done");
            }
            // printf("cur : %d\n", cur);
            last_ans = dis(p, f[cur][0]);
        }
        printf("%d\n", last_ans);
    }
    return 0;
}
posted @ 2023-01-13 08:39  kymru  阅读(14)  评论(0)  编辑  收藏  举报