P5327 题解

考虑把贡献摊到每个点上计算,每个点带来的贡献实际上是经过它的路径并大小,算完求和之后在除以 \(2\) 就得到了答案。

考虑怎么计算路径并大小。

考虑这样一个办法,将所有路径的起始点和终点按照 DFS 序排序,相邻两点(包括第一个会最后一个点)在树上的距离之和便是其路径并大小的两倍。原理的话便是路径并大小等价于包含所有路径起始点的最小联通生成树。

考虑树上点差分,然后用线段树储存 DFS 序为 \(x\) 的点是否在子树中,维护节点内最大和最小的存在的 DFS 序就可以在通过一次求 LCA 合并两个子节点的信息。

那么最后一步通过线段树合并将子树的信息合并到父亲即可。

时间复杂度 \(O(n \log^2 n)\) 空间复杂度 \(O(n \log n)\),也可以通过写压缩 01Trie 或者直接维护线段树叶子节点的方法做到线性空间,但没什么必要。

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn = 1e5 + 114;
int fa[maxn][19], lg[maxn];
int dfn[maxn], dfncnt, dep[maxn];
int Node[maxn];
int n, m;
vector<int> edge[maxn];
void dfs1(int u, int father) {
    dep[u] = dep[father] + 1;
    dfn[u] = ++dfncnt;
    Node[dfncnt] = u;
    fa[u][0] = father;

    for (int i = 1; i <= 17; i++)
        fa[u][i] = fa[fa[u][i - 1]][i - 1];

    for (int v : edge[u]) {
        if (v == father)
            continue;

        dfs1(v, u);
    }
}
int LCA(int u, int v) {
    if (dep[u] < dep[v])
        swap(u, v);

    while (dep[u] > dep[v]) {
        u = fa[u][lg[dep[u] - dep[v]]];
    }

    if (u == v)
        return u;

    for (int i = 17; i >= 0; i--) {
        if (fa[u][i] != fa[v][i]) {
            u = fa[u][i], v = fa[v][i];
        }
    }

    return fa[u][0];
}
int dist(int x, int y) {
    return dep[x] + dep[y] - 2 * dep[LCA(x, y)];
}
#define ls(cur)(tr[cur].ls)
#define rs(cur)(tr[cur].rs)
int tot;
struct Segment_tree {
    int ls, rs;
    int mi, mx, sum, cnt;
} tr[maxn * 40];
int root[maxn];
void pushup(int cur) {
    tr[cur].cnt = tr[ls(cur)].cnt + tr[rs(cur)].cnt;

    if (tr[ls(cur)].cnt == 0 && tr[rs(cur)].cnt == 0)
        cur = 0;
    else if (tr[ls(cur)].cnt == 0)
        tr[cur].sum = tr[rs(cur)].sum, tr[cur].mi = tr[rs(cur)].mi, tr[cur].mx = tr[rs(cur)].mx;
    else if (tr[rs(cur)].cnt == 0)
        tr[cur].sum = tr[ls(cur)].sum, tr[cur].mi = tr[ls(cur)].mi, tr[cur].mx = tr[ls(cur)].mx;
    else {
        tr[cur].sum = tr[ls(cur)].sum + tr[rs(cur)].sum + dist(Node[tr[ls(cur)].mx], Node[tr[rs(cur)].mi]);
        tr[cur].mi = tr[ls(cur)].mi;
        tr[cur].mx = tr[rs(cur)].mx;
    }
}
void update(int &cur, int lt, int rt, int pos, int v) {
    if (pos < lt || pos > rt)
        return ;

    if (cur == 0)
        cur = ++tot;

    if (lt == rt && lt == pos) {
        tr[cur].cnt += v;
        tr[cur].mi = tr[cur].mx = lt;
        tr[cur].sum = 0;
        return ;
    }

    int mid = (lt + rt) >> 1;
    update(ls(cur), lt, mid, pos, v);
    update(rs(cur), mid + 1, rt, pos, v);
    pushup(cur);
}
int merge(int a, int b, int lt, int rt) {
    if (a == 0 || b == 0)
        return a + b;

    if (lt == rt) {
        tr[a].cnt += tr[b].cnt;
        tr[a].mi = tr[a].mx = lt;
        tr[a].sum = 0;
        return a;
    }

    int mid = (lt + rt) >> 1;
    tr[a].ls = merge(tr[a].ls, tr[b].ls, lt, mid);
    tr[a].rs = merge(tr[a].rs, tr[b].rs, mid + 1, rt);
    pushup(a);
    return a;
}
vector<int> Ins[maxn], Del[maxn];
int answer;
void dfs2(int u, int father) {
    for (int v : edge[u]) {
        if (v == father)
            continue;

        dfs2(v, u);
        root[u] = merge(root[u], root[v], 1, n);
    }

    for (int x : Ins[u])
        update(root[u], 1, n, x, 1);

    for (int x : Del[u])
        update(root[u], 1, n, x, -1);

    answer += (tr[root[u]].sum + dist(Node[tr[root[u]].mi], Node[tr[root[u]].mx])) / 2;
}
signed main() {
    cin >> n >> m;
    lg[1] = 0;

    for (int i = 2; i <= n; i++)
        lg[i] = lg[i / 2] + 1;

    for (int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        edge[u].push_back(v);
        edge[v].push_back(u);
    }

    dep[0] = -1;
    dfs1(1, 0);

    for (int i = 1; i <= m; i++) {
        int u, v;
        cin >> u >> v;

        if (u == v)
            continue;

        Ins[u].push_back(dfn[u]);
        Ins[u].push_back(dfn[v]);
        Ins[v].push_back(dfn[v]);
        Ins[v].push_back(dfn[u]);
        Del[LCA(u, v)].push_back(dfn[u]);
        Del[LCA(u, v)].push_back(dfn[v]);
        Del[fa[LCA(u, v)][0]].push_back(dfn[u]);
        Del[fa[LCA(u, v)][0]].push_back(dfn[v]);
    }

    dfs2(1, 0);
    cout << answer / 2;
}
posted @ 2024-02-27 18:11  ChiFAN鸭  阅读(12)  评论(0编辑  收藏  举报