2024CCPC哈尔滨 L 题解
思路
首先可以发现这个期望其实是假的,我们只需要把所有方案的答案加起来,最后除以 \((\frac{n(n-1)}{2})^2\) 即可,现在考虑如何统计所有方案的答案。
我们先考虑一条路径的方案数:假设存在一条从 \(x\) 到 \(y\) 的公共路径,其中 \(x\) 是 \(y\) 的祖先,那么小红和小蓝分别选择的路径,其中一边的端点肯定在 \(y\) 的子树内,总共的方案数为 \(siz[y]^2\)。但是这并不正确,我们不妨记 \(y\) 的儿子为 \(y_1, y_2,...,y_m\),那么如果选择的端点都落在 \(y_1\) 内,那么实际的公共路径就变为了 \(x \rightarrow y_1\),所以实际的方案数是 \(siz[y]^2- \sum_{y_i \in son_y}siz[y_i]^2\),下面我们记 \(siz2[y]\) 为 \(\sum_{y_i \in son_y}siz[y_i]^2\)。
知道这个之后,我们可以考虑树形dp,由于 \((a+1)^2=a^2+2a+1\),所以我们可以考虑维护所有路径的 \(i\) 次项和,即方案数乘路径长度的 \(i\) 次方。
记 \(f[x][0/1/2]\) 为 \(x\) 子树内所有到 \(x\) 的路径的 \(0/1/2\) 次项之和,现在考虑转移:记 \(y\) 为 \(x\) 的一个儿子,现在我们要把 \(y\) 向 \(x\) 合并,首先因为会加入 \(x-y\) 这条边,所以所有边的长度都会加一,即:
同时 \(x-y\) 这条边也要算进我们的方案数,即:
这样处理完之后直接加到 \(f[x]\) 里即可。
考虑完转移,现在我们考虑如何计算答案,同样是 \(y\) 向 \(x\) 转移的过程,我f们可以从 \(x\) 和 \(y\) 里面分别选出一条路径拼成一条完整的路径,记两条路径分别为 \(h_1,h_2\),有:
除此之外,\(y\) 中的路径也可以单独成为一条可行的路径,这种情况下 \(x\) 侧端点的方案数为
即两个端点不能都在同一个 \(x\) 的儿子的子树内,也不能都在 \(x\) 的父亲之外,这里可以自己画画图理解一下。
到这里就把这道题解决了,更多细节见代码
代码
#include <bits/stdc++.h>
using i64 = long long;
constexpr int P = 998244353;
i64 power(i64 a, i64 b)
{
i64 res = 1;
for( ; b; b >>= 1, a = a * a % P)
if(b & 1) res = res * a % P;
return res;
}
void solve()
{
int n; std::cin >> n;
std::vector<std::vector<int>> adj(n);
for(int i = 1; i < n; i++)
{
int u, v; std::cin >> u >> v;
u--, v--;
adj[u].emplace_back(v);
adj[v].emplace_back(u);
}
std::vector<i64> siz(n), siz2(n);
auto init = [&](auto init, int x, int fa) -> void
{
siz[x] = 1;
for(auto y : adj[x])
{
if(y == fa) continue;
init(init, y, x);
siz[x] += siz[y];
siz2[x] += siz[y] * siz[y];
}
};
init(init, 0, -1);
std::vector f(n, std::vector<i64>(3));
i64 ans = 0;
auto dfs = [&](auto dfs, int x, int fa) -> void
{
for(auto y : adj[x])
{
if(y == fa) continue;
dfs(dfs, y, x);
f[y][2] = (f[y][2] + 2 * f[y][1] + f[y][0]) % P;
f[y][1] = (f[y][1] + f[y][0]) % P;
for(int i = 0; i < 3; i++) f[y][i] = (f[y][i] + siz[y] * siz[y] - siz2[y]) % P;
i64 res = (f[y][0] * f[x][2] + f[y][2] * f[x][0] + 2 * f[y][1] * f[x][1]) % P;
i64 res2 = f[y][2] * ((n - siz[y]) * (n - siz[y]) - (siz2[x] - siz[y] * siz[y]) - (n - siz[x]) * (n - siz[x]));
ans = (ans + res + res2) % P;
for(int i = 0; i < 3; i++) f[x][i] = (f[x][i] + f[y][i]) % P;
}
};
dfs(dfs, 0, -1);
i64 inv = power(1LL * n * (n - 1) / 2 % P, P - 2);
ans = ans * inv % P * inv % P;
std::cout << ans << "\n";
}
int main()
{
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
int t; std::cin >> t;
while(t--) solve();
return 0;
}