Solution -「CF 1060F」Shrinking Tree
\(\mathcal{Description}\)
Link.
给定一棵 \(n\) 个点的树,反复随机选取一条边,合并其两端两点,新点编号在两端两点等概率选取。问每个点留到最后的概率。
\(n\le50\)。
\(\mathcal{Solution}\)
所有的操作方案数是 \((n-1)!\),我们可以按删边顺序看做一个长度为 \(n-1\) 的序列。对于每个点分别计算答案,把当前要算的点提为根(记为 \(r\)),我们只需要求出 \(r\) 在所有操作序列中存活的概率和除以 \((n-1)!\) 即可。
令 \(f(u,i)\) 为 \(r\) 已经走到 \(u\),\(u\) 子树内还剩下 \(i\) 条边没删(没加入删边序列),最终 \(u\)(即 \(r\))存活的概率和。显然答案为 \(f(r,n-1)\),边界 \(f(leaf,0)=1\)。
第一步,考虑儿子 \(v\) 与 \(u\) 合并。相当于需要考虑边 \((u,v)\) 在操作序列中的位置。粗略来说,若 \(r\) 没走到 \(u\),我们并不关心 \(u\) 号结点的生死;而 \(r\) 到 \(u\) 后,\(u\) 就必须存活。
定义辅助状态 \(g(u,i)\) 表示 \(u\) 子树内以及 \(u\) 的父边还剩下 \(i\) 条边没删,最终 \(u\) 存活的概率和。现在我们要计算 \(g(u,i)\)。
第一类,\((u,v)\) 保留到 \(r\) 到达 \(u\) 后再删,那么就涉及到 \(u\) 点存活的概率。于是有转移:
第二类,\((u,v)\) 在 \(r\) 到达 \(u\) 之前就删,那就很随意啦—— \(v\) 子树中已删除了 \(siz_v-1-i\) 条边,我们把 \((u,v)\) 随便插进一个位置就好,即:
上两类转移贡献之和即为最终的 \(g(u,i)\)。
考虑合并,始终记住删除的“序列意义”——保留的边(状态第二维)在删边序列的右端,其它的边在删边序列的左端。合并两个删边序列,仍需要保证这一点,那么分别用组合数合并已删除的左端序列和待删除的右端序列即可。下是 @ywy_c_asm 博客的一张图 owo(红色已删除,蓝色待删除):
答案呼之欲出啦:
两个组合数分别对应分配已删和待删的方案数。
最终,复杂度 \(\mathcal O(n^4)\) 解决了这道毒瘤 DP qwq。
\(\mathcal{Code}\)
#include <cstdio>
#include <cstring>
const int MAXN = 50;
int n, ecnt, head[MAXN + 5], siz[MAXN + 5];
double fac[MAXN + 5];
double f[MAXN + 5][MAXN + 5];
double g[MAXN + 5], h[MAXN + 5];
struct Edge { int to, nxt; } graph[MAXN * 2 + 5];
inline void link ( const int s, const int t ) {
graph[++ ecnt] = { t, head[s] };
head[s] = ecnt;
}
inline void init () {
fac[0] = 1;
for ( int i = 1; i <= n; ++ i ) fac[i] = fac[i - 1] * i;
}
inline double comb ( const int n, const int m ) {
return n < m ? 0 : fac[n] / fac[m] / fac[n - m];
}
inline void solve ( const int u, const int fa ) {
f[u][0] = siz[u] = 1;
for ( int i = head[u], v; i; i = graph[i].nxt ) {
if ( ( v = graph[i].to ) ^ fa ) {
solve ( v, u );
for ( int j = 0; j <= siz[v]; ++ j ) {
g[j] = 0;
for ( int k = 0; k < j; ++ k ) g[j] += 0.5 * f[v][k];
g[j] += ( siz[v] - j ) * f[v][j];
}
for ( int j = 0; j <= siz[v] + siz[u]; ++ j ) h[j] = 0;
for ( int j = 0; j < siz[u]; ++ j ) {
for ( int k = 0; k <= siz[v]; ++ k ) {
h[j + k] += f[u][j] * g[k] * comb ( j + k, j )
* comb ( siz[u] + siz[v] - 1 - j - k, siz[u] - 1 - j );
}
}
siz[u] += siz[v];
for ( int j = 0; j <= siz[u]; ++ j ) f[u][j] = h[j];
}
}
}
int main () {
scanf ( "%d", &n ), init ();
for ( int i = 1, u, v; i < n; ++ i ) {
scanf ( "%d %d", &u, &v );
link ( u, v ), link ( v, u );
}
for ( int i = 1; i <= n; ++ i ) {
memset ( f, 0, sizeof f );
solve ( i, 0 );
printf ( "%.12f\n", f[i][n - 1] / fac[n - 1] );
}
return 0;
}