题解【UR #20】跳蚤电话

link

首先转化题意,变成有多少种删除点的顺序能够将所有点删完。

于是我们可以做树形dp,设 \(f_i\) 表示 \(i\) 的子树内删完的排列数。

但是这样因为排列的原因,两个子树之间不独立,不好转移。

这个时候我们可以算随机排列能删完的概率,最后乘上 \((n-1)!\) 就是答案(因为 \(1\) 不能删除)。

于是设 \(f_i\) 表示 \(i\) 的子树按随机顺序删点,能删完的概率。

答案显然是 \(\prod_{u\in{son(1)}} f_u \times (n-1)!\)

考虑转移,对于 \(i\) 的子树,我们钦定他最后一个被删除的点,设这个点为 \(j\)

\(j=i\),则贡献显然是 \(\prod_{u\in son(i)}f_u\) ,因为只要是 \(i\) 在最后一个的合法排列都可以。

\(j\ne i\),假设 \(i\)\(j\) 路径上的点为 \(a_1,a_2,\dots,a_k\),那么删除每个 \(a_l\)时,他的除了 \(a_l+1\) 的所有子树都应该已经删除完,所以贡献是\(\prod_{l=1}^{k} \frac{1}{siz_{a_l}-siz_{a_{l+1}}}\prod_{u\in son(a_l),u\ne a_{l+1}}f_u\)

当然两种情况都要乘一个 \(\frac{1}{siz_i}\)

这样做的复杂度是 \(O(n^2)\),不能通过。

考虑优化,我们注意到如果我们钦定的点是 \(i\) 的儿子 \(u\) 子树中一点 \(j\),那么 \(j\)\(i\) 的贡献只是对 \(u\) 的贡献前边加上一个 \(i\)

于是我们设 \(g_i=f_i\times siz_i\),有 \(g_i=\prod_{u\in son(i)}f_i+\sum_{u\in son(i)}g_u \frac{1}{siz_i-siz_u}\prod_{v\in son(i),u\ne v}f_v,f_i=g_i\times\frac{1}{siz_i}\)

每次算出所有儿子的 \(f\) 的乘积,然后撤销每个儿子的贡献算答案即可,这样复杂度是 \(O(n\log n)\),因为还要求逆元。

当然也可以维护一个前缀后缀拼起来,复杂度 \(O(n)\)只不过我没写。

\(\sf{Code}\)

#include<bits/stdc++.h>
#define N 2001001
#define MAX 2001
using namespace std;
typedef long long ll;
typedef double db;
const ll mod=998244353,inf=1e18,inv2=(mod+1)/2;
inline void read(ll &ret)
{
	ret=0;char c=getchar();bool pd=false;
	while(!isdigit(c)){pd|=c=='-';c=getchar();}
	while(isdigit(c)){ret=(ret<<1)+(ret<<3)+(c&15);c=getchar();}
	ret=pd?-ret:ret;
	return;
}
ll n,x,y;
ll f[N],g[N];
ll siz[N];
vector<ll>v[N];
inline ll qpow(ll a,ll b)
{
	ll ret=1;
	while(b)
	{
		if(b&1)
			ret*=a,ret%=mod;
		b>>=1;
		a*=a;
		a%=mod;
	}
	return ret;
}
inline void dfs(ll ver,ll fa)
{
	siz[ver]=1;
	g[ver]=1;
	ll res=1;
	for(int i=0;i<v[ver].size();i++)
	{
		ll to=v[ver][i];
		if(to==fa)
			continue;
		dfs(to,ver);
		siz[ver]+=siz[to];
		g[ver]*=f[to];
		g[ver]%=mod;
		res*=f[to];
		res%=mod;
	}
	for(int i=0;i<v[ver].size();i++)
	{
		ll to=v[ver][i];
		if(to==fa)
			continue;
		g[ver]+=g[to]*res%mod*qpow(f[to],mod-2)%mod*qpow(siz[ver]-siz[to],mod-2)%mod;
		if(g[ver]>=mod)
			g[ver]-=mod;
	}
	f[ver]=g[ver]*qpow(siz[ver],mod-2)%mod;
	return;
}
signed main()
{
	read(n);
	for(int i=1;i<n;i++)
	{
		read(x);
		read(y);
		v[x].push_back(y);
		v[y].push_back(x);
	}
	dfs(1,0);
	ll ans=1;
	for(int i=0;i<v[1].size();i++)
		ans*=f[v[1][i]],ans%=mod;
	ll tmp=1;
	for(int i=2;i<n;i++)
		tmp*=i,tmp%=mod;
	printf("%lld\n",ans*tmp%mod);
	exit(0);
}
posted @ 2022-05-04 12:57  CelticOIer  阅读(123)  评论(0编辑  收藏  举报