AtCoder Beginner Contest 221 F Diameter set

洛谷传送门

AtCoder 传送门

显然,选出的每两个点都要组成一条直径。

进一步发现,设直径点数为 \(x\),如果 \(x \nmid 2\),所有直径都会在中点重合,否则会在连接两个中点的边重合。简单证一下,如果有两条直径不在中点或中边重合,那么:

  • 它们不可能不重合,要不然就不会成为直径了;
  • 它们在除了中点和中边的地方重合,此时长的两端构成了一条新的直径。

然后考虑计数。

  • 如果重合部分是点 \(u\),答案为 \(\prod\limits_{v, (u,v) \in E} (f_v + 1) - a - 1\),其中 \(f_v\)\(u\) 为根时 \(v\) 的子树中有多少个点与 \(u\) 的距离为 \(\frac{x-1}{2}\)\(a\) 为整棵树中有多少个点与 \(u\) 的距离为 \(\frac{x-1}{2}\)。大概意思就是每个点可选可不选,最后减去 \(|S| \le 1\)\(S\)

  • 如果重合部分是边 \((u,v)\),答案为 \(f_u \times g_v\),其中 \(f_u\) 为以 \(v\) 为根时 \(u\) 的子树中有多少个点到 \(u\) 距离为 \(\frac{x}{2} - 1\)\(g_v\) 为以 \(u\) 为根时 \(v\) 的子树中有多少个点到 \(v\) 距离为 \(\frac{x}{2} - 1\)

时间复杂度 \(O(n)\)

code
// Problem: F - Diameter set
// Contest: AtCoder - AtCoder Beginner Contest 221
// URL: https://atcoder.jp/contests/abc221/tasks/abc221_f
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))

using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef double db;
typedef long double ldb;
typedef pair<ll, ll> pii;

const int maxn = 200100;
const ll mod = 998244353;

ll n, head[maxn], len;
ll point, maxd, fa[maxn], D, cnt;

struct edge {
	int to, next;
} edges[maxn << 1];

inline void add_edge(int u, int v) {
	edges[++len].to = v;
	edges[len].next = head[u];
	head[u] = len;
}

void dfs(int u, int f, int d) {
	if (d > maxd) {
		maxd = d;
		point = u;
	}
	fa[u] = f;
	for (int i = head[u]; i; i = edges[i].next) {
		int v = edges[i].to;
		if (v == f) {
			continue;
		}
		dfs(v, u, d + 1);
	}
}

void dfs2(int u, int f, int d) {
	if (d == D) {
		++cnt;
	}
	for (int i = head[u]; i; i = edges[i].next) {
		int v = edges[i].to;
		if (v == f) {
			continue;
		}
		dfs2(v, u, d + 1);
	}
}

void solve() {
	scanf("%lld", &n);
	for (int i = 1, u, v; i < n; ++i) {
		scanf("%d%d", &u, &v);
		add_edge(u, v);
		add_edge(v, u);
	}
	dfs(1, -1, 1);
	int S = point;
	maxd = 0;
	dfs(S, -1, 1);
	int T = point;
	if (maxd & 1) {
		int x = T;
		for (int _ = 0; _ < maxd / 2; ++_) {
			x = fa[x];
		}
		ll ans = 1;
		for (int i = head[x]; i; i = edges[i].next) {
			int u = edges[i].to;
			cnt = 0;
			D = maxd / 2 - 1;
			dfs2(u, x, 0);
			ans = ans * (cnt + 1) % mod;
		}
		D = maxd / 2;
		cnt = 0;
		dfs2(x, -1, 0);
		ans -= cnt + 1;
		printf("%lld\n", ans);
	} else {
		int x = T;
		for (int _ = 0; _ < maxd / 2 - 1; ++_) {
			x = fa[x];
		}
		int y = fa[x];
		D = maxd / 2 - 1;
		cnt = 0;
		dfs2(y, x, 0);
		ll ans = cnt;
		cnt = 0;
		dfs2(x, y, 0);
		ans = ans * cnt % mod;
		printf("%lld\n", ans);
	}
}

int main() {
	int T = 1;
	// scanf("%d", &T);
	while (T--) {
		solve();
	}
	return 0;
}

posted @ 2023-05-10 21:57  zltzlt  阅读(39)  评论(0编辑  收藏  举报