LGP5450题解

萌萌多项式。

先考虑有根树指定根节点的拓扑排序方案数。

\(dp[u]\)\(u\) 为根子树的方案数,容易发现有

\[dp[u]=(\sum_{v\in son(u)}siz[v])!\prod_{v\in son(u)}\frac{1}{siz[v]!}=(siz[u]-1)!\prod_{v\in son(u)}\frac{1}{siz[v]!} \]

稍微推导一下可以得到 \(f[rt]\) 是:

\[n\times n!\prod_{i=1}^{n}\frac{1}{siz[i]} \]

考虑原问题。

枚举 \((a,b)\) 路径上的一条断边,表示一边全部是被 \(a\) 扩展成红色节点的,另一边全是被 \(b\) 扩展的。

容易发现非 \((a,b)\) 路径上节点的 \(siz\) 都是固定的,是常数。

于是只需要考虑 \((a,b)\) 路径上节点的 \(siz\) 即可。

\(a_i\) 表示断掉 \((a,b)\) 路径上所有边后,原本路径上第 \(i\) 个节点所在的连通块大小,\(S\)\(a\) 的前缀和。

可以随手写出柿子答案是:

\[\sum_{i=1}^{n-1}\frac{1}{\prod_{j=1}^{i}(S_i-S_{j-1})\prod_{j=i+1}^{n}(S_j-S_i)} \]

\(f_i=\lim_{x\to S_i}\frac{\prod_{j=1}^{n}(x-S_j)}{x-S_i}\),不难发现答案是:

\[\sum_{i=1}^{n-1}\frac{1}{S_if_i(-1)^{n-i}} \]

根据洛必达法则,设:

\[F(x)=\prod_{i=1}^{n}(x-S_i) \]

有:

\[f_i=G'(S_i) \]

于是写个分治乘再写个多点求值就做完了,复杂度 \(O(n\log^2n)\)

但是注意到实际上算重了。

因为有一个节点是最后被计算到的,将它归为左部分和右部分各算了一次。(除了端点)

所以直接把两个端点分别作为黑点的方案算出来,加上之后除以 \(2\) 即可。复杂度不变。

#include<algorithm>
#include<cstdio>
#include<cctype>
#include<vector>
#define IMP(lim,act) for(int qwq=(lim),i=0;i^qwq;++i)act
const int M=3e5+5,mod=998244353;
int Inv[M],buf[M<<2];int*now=buf,*w[23];
inline void swap(int&a,int&b){
	int c=a;a=b;b=c;
}
inline int Add(const int&a,const int&b){
	return a+b>=mod?a+b-mod:a+b;
}
inline int Del(const int&a,const int&b){
	return b>a?a-b+mod:a-b;
}
inline int max(const int&a,const int&b){
	return a>b?a:b;
}
inline void write(int n){
	static char s[10];int top(0);while(s[++top]=n%10^48,n/=10);while(putchar(s[top--]),top);
}
inline int read(){
	int n(0);char s;while(!isdigit(s=getchar()));while(n=n*10+(s&15),isdigit(s=getchar()));return n;
}
inline int pow(int a,int b=mod-2){
	int ans(1);for(;b;b>>=1,a=1ull*a*a%mod)if(b&1)ans=1ull*ans*a%mod;return ans;
}
inline int Getlen(const int&n){
	int len(0);while((1<<len)<n)++len;return len;
}
inline void init(const int&n){
	const int&m=Getlen(n);w[m]=now;now+=1<<m;Inv[1]=1;
	for(int i=2;i<n;++i)Inv[i]=1ull*(mod-mod/i)*Inv[mod%i]%mod;
	w[m][0]=1;w[m][1]=pow(3,mod-1>>m+1);for(int i=2;i<(1<<m);++i)w[m][i]=1ull*w[m][i-1]*w[m][1]%mod;
	for(int k=m-1;k>=0&&(w[k]=now,now+=1<<k);--k)IMP(1<<k,w[k][i]=w[k+1][i<<1]);
}
struct Poly{
	std::vector<int>F;
	Poly(const Poly&G){F=G.F;}
	Poly(const std::vector<int>G){F=G;}
	Poly(const int&x=0){if(x)F=std::vector<int>(x);}
	inline Poly&resize(const int&len){
		F.resize(len);return*this;
	}
	inline int size()const{
		return F.size();
	}
	inline int&operator[](const int&id){
		return F[id];
	}
	inline void push_back(const int&x){
		F.push_back(x);
	}
	inline Poly&reverse(){
		std::reverse(F.begin(),F.end());return*this;
	}
	inline Poly operator>>(const int&x){
		Poly G;IMP(F.size()-x,G.push_back(F[i+x]));return G;
	}
	inline Poly operator<<(const int&x){
		Poly G;G.resize(x);IMP(F.size(),G.push_back(F[i]));return G;
	}
	inline int operator()(const int&x){
		int y(1),ans(0);IMP(F.size(),ans=(ans+1ull*F[i]*y)%mod),y=1ull*y*x%mod;return ans;
	}
	inline void px(Poly G){
		F.resize(max(F.size(),G.size()));G.resize(F.size());
		for(int i(0);i^F.size();++i)F[i]=1ull*F[i]*G[i]%mod;
	}
	inline Poly&Der(){
		for(int i(1);i^F.size();++i)F[i-1]=1ull*F[i]*i%mod;F.pop_back();return*this;
	}
	inline Poly&Int(){
		F.push_back(0);
		for(int i(F.size()-1);i;--i)F[i]=1ull*F[i-1]*::Inv[i]%mod;F[0]=0;return*this;
	}
	inline void DFT(const int&M){
		int i,k,d,x,y,len,*W,*L,*R;F.resize(1<<M);
		for(len=F.size()>>1,d=M-1;len;--d,len>>=1)for(k=0;k^F.size();k+=len<<1){
			W=w[d];L=&F[k];R=&F[k|len];IMP(len,(x=*L,y=*R)),*L++=Add(x,y),*R++=1ull**W++*Del(x,y)%mod;
		}
	}
	inline void IDFT(const int&M){
		int i,k,d,x,y,len,*W,*L,*R;F.resize(1<<M);
		for(len=1,d=0;len^F.size();len<<=1,++d)for(k=0;k^F.size();k+=len<<1){
			W=w[d];L=&F[k];R=&F[k|len];IMP(len,(x=*L,y=1ull**W++**R%mod)),*L++=Add(x,y),*R++=Del(x,y);
		}
		k=::pow(F.size());IMP(F.size(),F[i]=1ull*F[i]*k%mod);for(i=1;(i<<1)<F.size();++i)swap(F[i],F[F.size()-i]);
	}
	inline Poly operator+(Poly G)const{
		Poly F=this->F;F.resize(max(F.size(),G.size()));G.resize(F.size());
		IMP(F.size(),F[i]=Add(F[i],G[i]));return F;
	}
	inline Poly operator-(Poly G)const{
		Poly F=this->F;F.resize(max(F.size(),G.size()));G.resize(F.size());
		IMP(F.size(),F[i]=Del(F[i],G[i]));return F;
	}
	inline Poly operator*(const int&x)const{
		Poly F=this->F;IMP(F.size(),F[i]=1ull*F[i]*x%mod);return F;
	}
	inline Poly operator*(Poly G)const{
		Poly F=*this;const int&m=F.size()+G.size()-1,&len=Getlen(m);
		F.DFT(len);G.DFT(len);F.px(G);F.IDFT(len);return F.resize(m);
	}
	inline Poly operator/(Poly G){
		Poly F=*this,sav;const int&m=F.size()-G.size()+1;
		sav.resize(m);IMP(m,sav[i]=G.size()+i<m?0:G[G.size()-m+i]);
		sav.reverse().inv();sav*=F.reverse();
		return sav.resize(m).reverse();
	}
	inline Poly operator%(Poly G){
		return(*this-*this/G*G).resize(G.size()-1);
	}
	inline Poly&inv(){
		Poly b1,b2,b3;const int&m=Getlen(F.size());if(!F.empty())b1.push_back(::pow(F[0]));
		for(int len=1;len<=m;++len){
			b3=b1*2;(b2=F).resize(1<<len);
			b1.DFT(len+1);b1.px(b1);b2.DFT(len+1);b1.px(b2);b1.IDFT(len+1);
			b1=b3-b1.resize(1<<len);
		}
		return*this=b1.resize(F.size());
	}
	inline Poly&ln(){
		const int&m=F.size()-1;Poly G=*this;return(this->Der()*=G.inv()).resize(m).Int();
	}
	inline Poly&exp(){
		Poly b1,b2,b3;const int&m=Getlen(F.size());b1.push_back(1);
		for(int len=1;len<=m;++len){
			b3=b2=b1;b2.resize(1<<len).ln();b2=(*this-b2).resize(1<<len);++b2[0];
			b2.DFT(len);b3.DFT(len);b2.px(b3);b2.IDFT(len);b1.resize(1<<len);
			IMP(1<<len-1,b1[1<<len-1|i]=b2[1<<len-1|i]);
		}
		return*this=b1.resize(F.size());
	}
	inline Poly&sqrt(){
		Poly b1,b2;const int&m=Getlen(F.size());b1.push_back(1);
		for(int len=1;len<=m;++len){
			b2=(b1*2).resize(1<<len).inv();
			b1.DFT(len);b1.px(b1);b1.IDFT(len);
			b1=((*this+b1).resize(1<<len)*b2).resize(1<<len);
		}
		return*this=b1.resize(F.size());
	}
	inline Poly&pow(const int&k){
		ln();IMP(F.size(),F[i]=1ull*F[i]*k%mod);return exp();
	}
	inline Poly&operator>>=(const int&x){
		return*this=operator>>(x);
	}
	inline Poly&operator<<=(const int&x){
		return*this=operator<<(x);
	}
	inline Poly&operator+=(const Poly&G){
		return*this=*this+G;
	}
	inline Poly&operator-=(const Poly&G){
		return*this=*this-G;
	}
	inline Poly&operator*=(const Poly&G){
		return*this=*this*G;
	}
	inline Poly&operator/=(const Poly&G){
		return*this=*this/G;
	}
	inline Poly&operator%=(const Poly&G){
		return*this=*this%G;
	}
};
inline Poly resize(Poly F,const int&n){
	return F.resize(n);
}
inline Poly reverse(Poly F){
	return F.reverse();
}
inline Poly Int(Poly F){
	return F.Int();
}
inline Poly Der(Poly F){
	return F.Der();
}
inline Poly px(Poly F,Poly G){
	return F.px(G),F;
}
inline Poly inv(Poly F){
	return F.inv();
}
inline Poly ln(Poly F){
	return F.ln();
}
inline Poly exp(Poly F){
	return F.exp();
}
inline Poly sqrt(Poly F){
	return F.sqrt();
}
inline Poly pow(Poly F,const int&k){
	return F.pow(k);
}
int n,a,b,ege,f[M],h[M],S[M],sz[M],siz[M];bool vis[M];int len,nd[M];
Poly F[M<<2];
struct Edge{
	int v,nx;
}e[M<<1];
inline void AddEdge(const int&u,const int&v){
	e[++ege]=(Edge){v,h[u]};h[u]=ege;
	e[++ege]=(Edge){u,h[v]};h[v]=ege;
}
inline bool Find(const int&u,const int&fa){
	if(u==a)return vis[nd[++len]=u]=true;
	for(int v,E=h[u];E;E=e[E].nx)if((v=e[E].v)^fa&&Find(v,u))return vis[nd[++len]=u]=true;
	return false;
}
inline void DFS(const int&u,const int&fa){
	siz[u]=1;for(int v,E=h[u];E;E=e[E].nx)if(v=e[E].v,v^fa&&!vis[v])DFS(v,u),siz[u]+=siz[v];
}
inline void Build(const int&u,const int&L,const int&R){
	if(L==R)return F[u].push_back(mod-S[L]),F[u].push_back(1);
	const int&mid=L+R>>1;Build(u<<1,L,mid);Build(u<<1|1,mid+1,R);F[u]=F[u<<1]*F[u<<1|1];
}
inline void Solve(const int&u,const int&L,const int&R,Poly H){
	if(L==R)return void(f[L]=H[0]);
	const int&mid=L+R>>1;Solve(u<<1,L,mid,H%F[u<<1]);Solve(u<<1|1,mid+1,R,H%F[u<<1|1]);
}
inline int Get(const int&u,const int&fa){
	int prod(1);sz[u]=1;for(int v,E=h[u];E;E=e[E].nx)if((v=e[E].v)^fa)prod=1ll*prod*Get(v,u)%mod,sz[u]+=sz[v];
	return 1ll*prod*sz[u]%mod;
}
signed main(){
	int sum(0),prod(1);n=read();a=read();b=read();init(n+1);
	for(int u,v,i=1;i<n;++i)u=read(),v=read(),AddEdge(u,v);Find(b,0);
	for(int i=1;i<=len;++i)DFS(nd[i],0);for(int i=1;i<=n;++i)if(!vis[i])prod=1ll*prod*siz[i]%mod;prod=pow(prod);
	if(len==1){
		for(int i=1;i<n;++i)prod=1ll*prod*i%mod;write(prod);
		return 0;
	}
	for(int i=1;i<=len;++i)S[i]=S[i-1]+siz[nd[i]];Build(1,1,len);Solve(1,1,len,F[1].Der());
	for(int i=1;i<len;++i)sum=(sum+pow(1ll*(len-i&1?mod-f[i]:f[i])*S[i]%mod))%mod;prod=1ll*prod*sum%mod;
	prod=(prod+pow(Get(a,0)))%mod;prod=(prod+pow(Get(b,0)))%mod;prod=1ll*prod*pow(2)%mod;
	for(int i=1;i<=n;++i)prod=1ll*prod*i%mod;write(prod);
}
posted @ 2022-08-24 20:42  Prean  阅读(12)  评论(0编辑  收藏  举报
var canShowAdsense=function(){return !!0};