AtCoder Grand Contest 058 F Authentic Tree DP
人生中第一道 AtCoder 问号题。
设 \(P = 998244353\)。
注意到 \(f(T)\) 的定义式中,\(\frac{1}{n}\) 大概是启示我们转成概率去做。发现若把 \(\frac{1}{n}\) 换成 \(\frac{1}{n - 1}\) 答案就是 \(1\),所以 \(\frac{1}{n}\) 大概是要转成点数之类的。
考虑把边转成点,若原树存在边 \((u, v)\),就新建点 \(p\),断开 \((u, v)\),连边 \((u, p), (p, v)\),称 \(p\) 点为边点。但是这样点数就变成 \(2n - 1\) 了。
但是!考虑再挂 \(P - 1\) 个叶子到 \(p\) 下面,点数就变成 \(n + (n - 1) \times P\)。模意义下 \(\frac{1}{n} \equiv \frac{1}{n + (n - 1)P} \pmod P\)。
我们可以把原问题转化成在新树上的这个问题:
随机生成一个排列 \(p_{1 \sim n + (n - 1)P}\),求所有边点的 \(p\) 值大于其所有邻居的 \(p\) 值的概率。
证明大概就是考虑不断取树上的最大值,取 \(n - 1\) 次,每次取边点的概率在模意义下等于 \(\frac{1}{n}\),转移式也与原题相同。
考虑给树上的边定向,从小连到大,那么就是要求每条边的起点都是边点的概率。像这样(图来自 kkio):
随便定一个根。发现有些 \(p \to u\) 的边从下连到上看起来不顺眼,考虑容斥。那么所有下连到上的边可以选择上连到下或者下连到上(断掉)。设有 \(k\) 条原本是下连到上的边,容斥系数为 \((-1)^k\)。
那么我们可以设 \(f_{u, i}\) 表示,\(u\) 的子树中以 \(u\) 为根的外向树大小模 \(P\) 意义下等于 \(i\),容斥系数乘概率之和。
对于一个边点 \(p\),在它原树上对应的边 \((u, v)\) 上统计贡献。
考虑若边 \((u, p)\) 从上到下,那 \(v\) 子树中以 \(v\) 为根的的外向树可以直接接到 \(u\) 下面,直接树形背包合并,\(f_{u, i + j} \gets -f'_{u, i} \times \frac{f_{v, j}}{j}\)。乘 \(\frac{1}{j}\) 是计入边点作为外向树的根的概率,乘 \(-1\) 是计入容斥系数。若边 \((u, p)\) 断开,那么 \(f_{u, i} \gets f'_{u, i} \times \sum\limits_{j = 1}^{sz_v} \frac{f_{v, j}}{j}\)。最后还要 \(f_{u, i} \gets \frac{f_{u, i}}{i}\),表示 \(u\) 点作为外向树的根的概率。
时间复杂度 \(O(n^2)\)。
code
// Problem: F - Authentic Tree DP
// Contest: AtCoder - AtCoder Grand Contest 058
// URL: https://atcoder.jp/contests/agc058/tasks/agc058_f
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 5050;
const ll mod = 998244353;
ll n, f[maxn][maxn], g[maxn], h[maxn], inv[maxn], sz[maxn];
vector<int> G[maxn];
void dfs(int u, int fa) {
sz[u] = 1;
f[u][1] = 1;
for (int v : G[u]) {
if (v == fa) {
continue;
}
dfs(v, u);
for (int i = 1; i <= sz[u] + sz[v]; ++i) {
h[i] = f[u][i];
f[u][i] = 0;
}
for (int i = 1; i <= sz[u]; ++i) {
for (int j = 1; j <= sz[v]; ++j) {
f[u][i + j] = (f[u][i + j] - h[i] * f[v][j] % mod * inv[j] % mod + mod) % mod;
}
}
for (int i = 1; i <= sz[u]; ++i) {
f[u][i] = (f[u][i] + h[i] * g[v] % mod) % mod;
}
sz[u] += sz[v];
}
for (int i = 1; i <= sz[u]; ++i) {
f[u][i] = f[u][i] * inv[i] % mod;
g[u] = (g[u] + f[u][i] * inv[i] % mod) % mod;
}
}
void solve() {
scanf("%lld", &n);
inv[1] = 1;
for (int i = 2; i <= n; ++i) {
inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
G[u].pb(v);
G[v].pb(u);
}
dfs(1, -1);
ll ans = 0;
for (int i = 1; i <= n; ++i) {
ans = (ans + f[1][i]) % mod;
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}