Loading

【题解】P5666 [CSP-S2019] 树的重心

感觉对重心的理解更直观了一点。

题意

求一棵树上删去每一条边后两侧子树重心的编号和。

\(n \leq 3 \times 10^5\)

思路

神奇的清真树论。

首先这里有一步很妙的操作:把整棵树的重心 \(rt\) 设为根。

答案可以转化成考虑每个结点的贡献,即询问有多少条边删去以后可以使 \(rt\) 以外的结点 \(x\) 成为重心。

显然删去的这条边不能在 \(x\) 的子树内。

\(s_u\) 为结点 \(u\) 的子树大小,\(g_u = \max\limits_{v \in son(u)} s_v\),其中 \(son(u)\) 表示结点 \(u\) 的所有子结点。根据重心的定义,令删去这条边以后另一侧的子树大小为 \(S\),则 \(S\) 应该满足:

\(2(n - S - s_x) \leq n - S, 2 g_x \leq n - S\)

如果不考虑在子树外的限制,这里可以在 dfs 的同时用树状数组维护 \(S\) 的取值个数,问题转化成单点加区间求和。

对于子树外的限制。可以另外维护一个树状数组记录下所有经过的结点中 \(S\) 的取值,这样回溯的时候容斥一下就行。

考虑重心的贡献。设 \(u, v\) 分别是 \(rt\) 的子结点中 \(s\) 最大和次大的子结点。如果删去的边在 \(u\) 的子树内,则要满足 \(2 s_v \leq n - S\),反之要满足 \(2 s_u \leq n - S\),dfs 的时候顺便查询一下就行。

时间复杂度 \(O(n \log n)\)

代码

#include <cstdio>
#include <cstring>
#include <vector>
#include <iostream>
using namespace std;

typedef long long ll;

const int maxn = 3e5 + 5;
const int maxm = 6e5 + 5;
const int inf = 0x3f3f3f3f;

int t, n, rt, u, v;
int head[maxn], sz[maxn], mx[maxn];
bool vis[maxn];
ll ans;
vector<int> g[maxn];

struct node
{
    int to, nxt;
} edge[maxm];

struct BIT
{
    int c[maxn];

    void clear() { memset(c, 0, (n + 2) * sizeof(int)); }

    int lowbit(int x) { return x & (-x); }

    void update(int p, int w) { p++; for (int i = p; i <= n + 1; i += lowbit(i)) c[i] += w; }

    int query(int p)
    {
        int res = 0;
        p++;
        for (int i = p; i; i -= lowbit(i)) res += c[i];
        return res;
    }
} c1, c2;

void dfs1(int x, int f)
{
    bool is_rt = true;
    sz[x] = 1, mx[x] = 0;
    for (int y : g[x])
    {
        if (y == f) continue;
        dfs1(y, x);
        sz[x] += sz[y];
        mx[x] = max(mx[x], sz[y]);
        if (sz[y] > (n >> 1)) is_rt = false;
    }
    if (n - sz[x] > (n >> 1)) is_rt = false;
    if (is_rt) rt = x;
}

void dfs2(int x, int f)
{
    c1.update(sz[f], -1);
    c1.update(n - sz[x], 1);
    vis[x] |= vis[f];
    if (x != rt)
    {
        ans += 1ll * x * c1.query(n - 2 * mx[x]);
        ans -= 1ll * x * c1.query(n - 2 * sz[x] - 1);
        ans += 1ll * x * c2.query(n - 2 * mx[x]);
        ans -= 1ll * x * c2.query(n - 2 * sz[x] - 1);
        ans += 1ll * rt * (int)(sz[x] <= n - 2 * sz[vis[x] ? v : u]);
    }
    c2.update(sz[x], 1);
    for (int y : g[x])
    {
        if (y == f) continue;
        dfs2(y, x);
    }
    c1.update(sz[f], 1);
    c1.update(n - sz[x], -1);
    if (x != rt)
    {
        ans -= 1ll * x * c2.query(n - 2 * mx[x]);
        ans += 1ll * x * c2.query(n - 2 * sz[x] - 1);
    }
}

int main()
{
    scanf("%d", &t);
    while (t--)
    {
        scanf("%d", &n);
        for (int i = 1; i <= n; i++) g[i].clear();
        for (int i = 1, u, v; i <= n - 1; i++)
        {
            scanf("%d%d", &u, &v);
            g[u].push_back(v);
            g[v].push_back(u);
        }
        ans = 0ll;
        dfs1(1, 0);
        dfs1(rt, 0);
        u = v = 0;
        for (int x : g[rt])
        {
            if (sz[x] > sz[v]) v = x;
            if (sz[v] > sz[u]) swap(u, v);
        }
        c1.clear(), c2.clear();
        for (int i = 0; i <= n; i++) c1.update(sz[i], 1), vis[i] = false;
        vis[u] = true;
        dfs2(rt, 0);
        printf("%lld\n", ans);
    }
    return 0;
}
posted @ 2023-01-08 16:00  kymru  阅读(169)  评论(0编辑  收藏  举报