【题解】P5666 [CSP-S2019] 树的重心
感觉对重心的理解更直观了一点。
题意
求一棵树上删去每一条边后两侧子树重心的编号和。
\(n \leq 3 \times 10^5\)
思路
神奇的清真树论。
首先这里有一步很妙的操作:把整棵树的重心 \(rt\) 设为根。
答案可以转化成考虑每个结点的贡献,即询问有多少条边删去以后可以使 \(rt\) 以外的结点 \(x\) 成为重心。
显然删去的这条边不能在 \(x\) 的子树内。
设 \(s_u\) 为结点 \(u\) 的子树大小,\(g_u = \max\limits_{v \in son(u)} s_v\),其中 \(son(u)\) 表示结点 \(u\) 的所有子结点。根据重心的定义,令删去这条边以后另一侧的子树大小为 \(S\),则 \(S\) 应该满足:
\(2(n - S - s_x) \leq n - S, 2 g_x \leq n - S\)
如果不考虑在子树外的限制,这里可以在 dfs 的同时用树状数组维护 \(S\) 的取值个数,问题转化成单点加区间求和。
对于子树外的限制。可以另外维护一个树状数组记录下所有经过的结点中 \(S\) 的取值,这样回溯的时候容斥一下就行。
考虑重心的贡献。设 \(u, v\) 分别是 \(rt\) 的子结点中 \(s\) 最大和次大的子结点。如果删去的边在 \(u\) 的子树内,则要满足 \(2 s_v \leq n - S\),反之要满足 \(2 s_u \leq n - S\),dfs 的时候顺便查询一下就行。
时间复杂度 \(O(n \log n)\)
代码
#include <cstdio>
#include <cstring>
#include <vector>
#include <iostream>
using namespace std;
typedef long long ll;
const int maxn = 3e5 + 5;
const int maxm = 6e5 + 5;
const int inf = 0x3f3f3f3f;
int t, n, rt, u, v;
int head[maxn], sz[maxn], mx[maxn];
bool vis[maxn];
ll ans;
vector<int> g[maxn];
struct node
{
int to, nxt;
} edge[maxm];
struct BIT
{
int c[maxn];
void clear() { memset(c, 0, (n + 2) * sizeof(int)); }
int lowbit(int x) { return x & (-x); }
void update(int p, int w) { p++; for (int i = p; i <= n + 1; i += lowbit(i)) c[i] += w; }
int query(int p)
{
int res = 0;
p++;
for (int i = p; i; i -= lowbit(i)) res += c[i];
return res;
}
} c1, c2;
void dfs1(int x, int f)
{
bool is_rt = true;
sz[x] = 1, mx[x] = 0;
for (int y : g[x])
{
if (y == f) continue;
dfs1(y, x);
sz[x] += sz[y];
mx[x] = max(mx[x], sz[y]);
if (sz[y] > (n >> 1)) is_rt = false;
}
if (n - sz[x] > (n >> 1)) is_rt = false;
if (is_rt) rt = x;
}
void dfs2(int x, int f)
{
c1.update(sz[f], -1);
c1.update(n - sz[x], 1);
vis[x] |= vis[f];
if (x != rt)
{
ans += 1ll * x * c1.query(n - 2 * mx[x]);
ans -= 1ll * x * c1.query(n - 2 * sz[x] - 1);
ans += 1ll * x * c2.query(n - 2 * mx[x]);
ans -= 1ll * x * c2.query(n - 2 * sz[x] - 1);
ans += 1ll * rt * (int)(sz[x] <= n - 2 * sz[vis[x] ? v : u]);
}
c2.update(sz[x], 1);
for (int y : g[x])
{
if (y == f) continue;
dfs2(y, x);
}
c1.update(sz[f], 1);
c1.update(n - sz[x], -1);
if (x != rt)
{
ans -= 1ll * x * c2.query(n - 2 * mx[x]);
ans += 1ll * x * c2.query(n - 2 * sz[x] - 1);
}
}
int main()
{
scanf("%d", &t);
while (t--)
{
scanf("%d", &n);
for (int i = 1; i <= n; i++) g[i].clear();
for (int i = 1, u, v; i <= n - 1; i++)
{
scanf("%d%d", &u, &v);
g[u].push_back(v);
g[v].push_back(u);
}
ans = 0ll;
dfs1(1, 0);
dfs1(rt, 0);
u = v = 0;
for (int x : g[rt])
{
if (sz[x] > sz[v]) v = x;
if (sz[v] > sz[u]) swap(u, v);
}
c1.clear(), c2.clear();
for (int i = 0; i <= n; i++) c1.update(sz[i], 1), vis[i] = false;
vis[u] = true;
dfs2(rt, 0);
printf("%lld\n", ans);
}
return 0;
}