题解【UR #20】跳蚤电话
首先转化题意,变成有多少种删除点的顺序能够将所有点删完。
于是我们可以做树形dp,设 \(f_i\) 表示 \(i\) 的子树内删完的排列数。
但是这样因为排列的原因,两个子树之间不独立,不好转移。
这个时候我们可以算随机排列能删完的概率,最后乘上 \((n-1)!\) 就是答案(因为 \(1\) 不能删除)。
于是设 \(f_i\) 表示 \(i\) 的子树按随机顺序删点,能删完的概率。
答案显然是 \(\prod_{u\in{son(1)}} f_u \times (n-1)!\)
考虑转移,对于 \(i\) 的子树,我们钦定他最后一个被删除的点,设这个点为 \(j\)。
若 \(j=i\),则贡献显然是 \(\prod_{u\in son(i)}f_u\) ,因为只要是 \(i\) 在最后一个的合法排列都可以。
若 \(j\ne i\),假设 \(i\) 到 \(j\) 路径上的点为 \(a_1,a_2,\dots,a_k\),那么删除每个 \(a_l\)时,他的除了 \(a_l+1\) 的所有子树都应该已经删除完,所以贡献是\(\prod_{l=1}^{k} \frac{1}{siz_{a_l}-siz_{a_{l+1}}}\prod_{u\in son(a_l),u\ne a_{l+1}}f_u\)。
当然两种情况都要乘一个 \(\frac{1}{siz_i}\)
这样做的复杂度是 \(O(n^2)\),不能通过。
考虑优化,我们注意到如果我们钦定的点是 \(i\) 的儿子 \(u\) 子树中一点 \(j\),那么 \(j\) 对 \(i\) 的贡献只是对 \(u\) 的贡献前边加上一个 \(i\)。
于是我们设 \(g_i=f_i\times siz_i\),有 \(g_i=\prod_{u\in son(i)}f_i+\sum_{u\in son(i)}g_u \frac{1}{siz_i-siz_u}\prod_{v\in son(i),u\ne v}f_v,f_i=g_i\times\frac{1}{siz_i}\)
每次算出所有儿子的 \(f\) 的乘积,然后撤销每个儿子的贡献算答案即可,这样复杂度是 \(O(n\log n)\),因为还要求逆元。
当然也可以维护一个前缀后缀拼起来,复杂度 \(O(n)\),只不过我没写。
\(\sf{Code}\)
#include<bits/stdc++.h>
#define N 2001001
#define MAX 2001
using namespace std;
typedef long long ll;
typedef double db;
const ll mod=998244353,inf=1e18,inv2=(mod+1)/2;
inline void read(ll &ret)
{
ret=0;char c=getchar();bool pd=false;
while(!isdigit(c)){pd|=c=='-';c=getchar();}
while(isdigit(c)){ret=(ret<<1)+(ret<<3)+(c&15);c=getchar();}
ret=pd?-ret:ret;
return;
}
ll n,x,y;
ll f[N],g[N];
ll siz[N];
vector<ll>v[N];
inline ll qpow(ll a,ll b)
{
ll ret=1;
while(b)
{
if(b&1)
ret*=a,ret%=mod;
b>>=1;
a*=a;
a%=mod;
}
return ret;
}
inline void dfs(ll ver,ll fa)
{
siz[ver]=1;
g[ver]=1;
ll res=1;
for(int i=0;i<v[ver].size();i++)
{
ll to=v[ver][i];
if(to==fa)
continue;
dfs(to,ver);
siz[ver]+=siz[to];
g[ver]*=f[to];
g[ver]%=mod;
res*=f[to];
res%=mod;
}
for(int i=0;i<v[ver].size();i++)
{
ll to=v[ver][i];
if(to==fa)
continue;
g[ver]+=g[to]*res%mod*qpow(f[to],mod-2)%mod*qpow(siz[ver]-siz[to],mod-2)%mod;
if(g[ver]>=mod)
g[ver]-=mod;
}
f[ver]=g[ver]*qpow(siz[ver],mod-2)%mod;
return;
}
signed main()
{
read(n);
for(int i=1;i<n;i++)
{
read(x);
read(y);
v[x].push_back(y);
v[y].push_back(x);
}
dfs(1,0);
ll ans=1;
for(int i=0;i<v[1].size();i++)
ans*=f[v[1][i]],ans%=mod;
ll tmp=1;
for(int i=2;i<n;i++)
tmp*=i,tmp%=mod;
printf("%lld\n",ans*tmp%mod);
exit(0);
}