AtCoder Beginner Contest 221 F Diameter set
显然,选出的每两个点都要组成一条直径。
进一步发现,设直径点数为 \(x\),如果 \(x \nmid 2\),所有直径都会在中点重合,否则会在连接两个中点的边重合。简单证一下,如果有两条直径不在中点或中边重合,那么:
- 它们不可能不重合,要不然就不会成为直径了;
- 它们在除了中点和中边的地方重合,此时长的两端构成了一条新的直径。
然后考虑计数。
-
如果重合部分是点 \(u\),答案为 \(\prod\limits_{v, (u,v) \in E} (f_v + 1) - a - 1\),其中 \(f_v\) 为 \(u\) 为根时 \(v\) 的子树中有多少个点与 \(u\) 的距离为 \(\frac{x-1}{2}\),\(a\) 为整棵树中有多少个点与 \(u\) 的距离为 \(\frac{x-1}{2}\)。大概意思就是每个点可选可不选,最后减去 \(|S| \le 1\) 的 \(S\)。
-
如果重合部分是边 \((u,v)\),答案为 \(f_u \times g_v\),其中 \(f_u\) 为以 \(v\) 为根时 \(u\) 的子树中有多少个点到 \(u\) 距离为 \(\frac{x}{2} - 1\),\(g_v\) 为以 \(u\) 为根时 \(v\) 的子树中有多少个点到 \(v\) 距离为 \(\frac{x}{2} - 1\)。
时间复杂度 \(O(n)\)。
code
// Problem: F - Diameter set
// Contest: AtCoder - AtCoder Beginner Contest 221
// URL: https://atcoder.jp/contests/abc221/tasks/abc221_f
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 200100;
const ll mod = 998244353;
ll n, head[maxn], len;
ll point, maxd, fa[maxn], D, cnt;
struct edge {
int to, next;
} edges[maxn << 1];
inline void add_edge(int u, int v) {
edges[++len].to = v;
edges[len].next = head[u];
head[u] = len;
}
void dfs(int u, int f, int d) {
if (d > maxd) {
maxd = d;
point = u;
}
fa[u] = f;
for (int i = head[u]; i; i = edges[i].next) {
int v = edges[i].to;
if (v == f) {
continue;
}
dfs(v, u, d + 1);
}
}
void dfs2(int u, int f, int d) {
if (d == D) {
++cnt;
}
for (int i = head[u]; i; i = edges[i].next) {
int v = edges[i].to;
if (v == f) {
continue;
}
dfs2(v, u, d + 1);
}
}
void solve() {
scanf("%lld", &n);
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
add_edge(u, v);
add_edge(v, u);
}
dfs(1, -1, 1);
int S = point;
maxd = 0;
dfs(S, -1, 1);
int T = point;
if (maxd & 1) {
int x = T;
for (int _ = 0; _ < maxd / 2; ++_) {
x = fa[x];
}
ll ans = 1;
for (int i = head[x]; i; i = edges[i].next) {
int u = edges[i].to;
cnt = 0;
D = maxd / 2 - 1;
dfs2(u, x, 0);
ans = ans * (cnt + 1) % mod;
}
D = maxd / 2;
cnt = 0;
dfs2(x, -1, 0);
ans -= cnt + 1;
printf("%lld\n", ans);
} else {
int x = T;
for (int _ = 0; _ < maxd / 2 - 1; ++_) {
x = fa[x];
}
int y = fa[x];
D = maxd / 2 - 1;
cnt = 0;
dfs2(y, x, 0);
ll ans = cnt;
cnt = 0;
dfs2(x, y, 0);
ans = ans * cnt % mod;
printf("%lld\n", ans);
}
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}