CodeForces 2033G Sakurako and Chefir 题解

题意:给定一棵树,多次询问在以节点 \(v\)\(k\) 级祖先为根的子树内,从点 \(v\) 出发的最长简单路径。

显然可以将所求路径分为终点在 \(v\) 的子树内和终点在 \(v\) 的子树外,前者容易处理。

考虑其他节点贡献的路径,容易想到枚举 \(v\) 每个合法祖先 \(u\) 作为 LCA,查询 \(v\) 所在的子树与 \(u\) 的其他子树之间的贡献。贡献可以使用主席树处理。具体而言,由树上路径计算方式 \(d = dis_u + dis_v - 2 \cdot dis_{lca}\),设 \(rt_u\) 表示记录节点 \(u\) 的子树内所有点的深度的主席树,查询时减去 \(v\) 所在子树的所有点的深度,并查询最大深度。

于是得到一个 \(O(qn \log n)\) 的做法,考虑优化。

发现每次查询时跳了很多祖先,而对于一个确定的祖先与其确定的子树,没有必要查询多次。使用树上启发式合并:预处理重儿子的答案作为该节点的权值,树剖查询链最大值,跳轻链时再使用前述方法。优化至 \(O(q \log^2 n)\),可以通过。

#include <algorithm>
#include <iostream>
#include <vector>

using namespace std;

const int inf = 1e9;

int lg[200005];

int n, q;
vector<int> vec[200005];

int rt[200005];
struct node {
    int ls, rs, val;
} d[16000005];
int cnt;
static inline int clone(int p) {
    d[++cnt] = d[p];
    return cnt;
}
static inline int insert(int x, int s, int t, int c, int p) {
    p = clone(p);
    if (s == t) {
        d[p].val += c;
        return p;
    }
    int mid = (s + t) >> 1;
    if (x <= mid)
        d[p].ls = insert(x, s, mid, c, d[p].ls);
    else
        d[p].rs = insert(x, mid + 1, t, c, d[p].rs);
    d[p].val = d[d[p].ls].val + d[d[p].rs].val;
    return p;
}
static inline int merge(int s, int t, int u, int v) {
    if (!u || !v)
        return clone(u | v);
    int p = ++cnt;
    if (s == t) {
        d[p].val = d[u].val + d[v].val;
        return p;
    }
    int mid = (s + t) >> 1;
    d[p].ls = merge(s, mid, d[u].ls, d[v].ls);
    d[p].rs = merge(mid + 1, t, d[u].rs, d[v].rs);
    d[p].val = d[d[p].ls].val + d[d[p].rs].val;
    return p;
}
static inline int query(int s, int t, int u, int v) {
    if (s == t)
        return d[v].val - d[u].val > 0 ? s : -inf;
    int mid = (s + t) >> 1;
    if (d[d[v].rs].val - d[d[u].rs].val > 0)
        return query(mid + 1, t, d[u].rs, d[v].rs);
    return query(s, mid, d[u].ls, d[v].ls);
}

int dep[200005];
int st[200005][20];
int siz[200005];
int son[200005];
int top[200005];
int dfn[200005], dfn_clock;
int nfd[200005];
int pre[200005][20];
static inline void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1;
    st[u][0] = fa;
    for (int i = 1; i <= 18; ++i)
        st[u][i] = st[st[u][i - 1]][i - 1];
    siz[u] = 1;
    rt[u] = insert(dep[u], 1, n, 1, rt[u]);
    for (auto v : vec[u]) {
        if (v == fa)
            continue;
        dfs(v, u);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]])
            son[u] = v;
        rt[u] = merge(1, n, rt[u], rt[v]);
    }
}
static inline void dfs2(int u) {
    nfd[dfn[u] = ++dfn_clock] = u;
    if (!son[u]) {
        pre[dfn[u]][0] = -inf;
        return;
    }
    top[son[u]] = top[u];
    dfs2(son[u]);
    pre[dfn[u]][0] = query(1, n, rt[son[u]], rt[u]) - 2 * dep[u];
    for (auto v : vec[u]) {
        if (v == st[u][0] || v == son[u])
            continue;
        top[v] = v;
        dfs2(v);
    }
}
static inline int jump(int x, int k) {
    for (int i = 0; i <= 18; ++i)
        if ((k >> i) & 1)
            x = st[x][i];
    return x;
}
static inline int LCA(int u, int v) {
    if (dep[u] < dep[v])
        swap(u, v);
    u = jump(u, dep[u] - dep[v]);
    if (u == v)
        return u;
    for (int i = 18; i >= 0; --i)
        if (st[u][i] != st[v][i]) {
            u = st[u][i];
            v = st[v][i];
        }
    return st[u][0];
}

static inline void build() {
    for (int i = 2; i <= n; ++i)
        lg[i] = lg[i >> 1] + 1;
    for (int j = 1; j <= 18; ++j)
        for (int i = 1; i + (1 << j) - 1 <= n; ++i)
            pre[i][j] = max(pre[i][j - 1], pre[i + (1 << (j - 1))][j - 1]);
}
static inline int query(int l, int r) {
    int len = lg[r - l + 1];
    return max(pre[l][len], pre[r - (1 << len) + 1][len]);
}

static inline void solve() {
    cin >> n;
    for (int i = 1; i <= cnt; ++i)
        d[i] = {0, 0, 0};
    cnt = dfn_clock = 0;
    for (int i = 1; i <= n; ++i) {
        vec[i].clear();
        rt[i] = son[i] = top[i] = 0;
        for (int j = 0; j <= 18; ++j)
            st[i][j] = pre[i][j] = 0;
    }
    for (int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        vec[u].push_back(v);
        vec[v].push_back(u);
    }
    dfs(1, 0);
    top[1] = 1;
    dfs2(1);
    build();
    cin >> q;
    while (q--) {
        int x, k;
        cin >> x >> k;
        int u = st[x][0];
        int v = max(jump(x, k), 1);
        int ans = min(k, dep[u]);
        while (u && dep[u] > dep[v] && top[u] != top[v]) { // DSU on tree
            int w = query(1, n, rt[jump(x, dep[x] - dep[u] - 1)], rt[u]);
            ans = max(ans, dep[x] + w - 2 * dep[u]);
            if (dfn[top[u]] < dep[u])
                ans = max(ans, dep[x] + query(dfn[top[u]], dfn[u] - 1));
            u = st[top[u]][0];
        }
        if (dep[u] >= dep[v]) {
            int w = query(1, n, rt[jump(x, dep[x] - dep[u] - 1)], rt[u]);
            ans = max(ans, dep[x] + w - 2 * dep[u]);
            if (dfn[v] < dfn[u])
                ans = max(ans, dep[x] + query(dfn[v], dfn[u] - 1));
        }
        // 这是暴力的实现方式
        // int ans = 0;
        // int u = st[x][0];
        // --k;
        // while (u && k >= 0) {
        //     int w = query(1, n, rt[jump(x, dep[x] - dep[u] - 1)], rt[u]);
        //     ans = max(ans, dep[x] + w - 2 * dep[u]);
        //     u = st[u][0];
        //     --k;
        // }
        ans = max(ans, query(1, n, 0, rt[x]) - dep[x]);
        cout << ans << ' ';
    }
    cout << endl;
}

signed main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    solve();
    cout.flush();
    return 0;
}
posted @ 2024-10-25 08:19  bluewindde  阅读(73)  评论(0编辑  收藏  举报