【WC2019】数树 树形DP 多项式exp

题目大意

  有两棵 \(n\) 个点的树 \(T_1\)\(T_2\)

  你要给每个点一个权值吗,要求每个点的权值为 \([1,y]\) 内的整数。

  对于一条同时出现在两棵树上的边,这条边的两个端点的值相同。

  若 \(op=0\),则给你两棵树 \(T_1,T_2\),求方案数。

  若 \(op=1\),则给你一棵树 \(T_1\),求对于所有 \(n^{n-2}\)\(T_2\),方案数之和。

  若 \(op=2\),则求对于所有的 \(T_1,T_2\),求方案数之和。

  \(n\leq 100000\)

题解

  新建一个图 \(G\),把两棵树的公共边加到 \(G\) 中。记 \(m\) 为两棵树的公共边数量。那么答案就是 \(y^{n-m}\)

  令 \(z=y^{-1}\),那么答案就变成了 \(y^nz^m\)。也就是说,每有一条相同的边,方案的贡献就要 \(\times z\)

op=0

  这个大家都会。

op=1

\[z^m=\sum_{i=0}^m\binom{m}{i}(z-1)^i \]

  那么可以枚举一个边集 \(E\),计算有多少种生成树包含 \(E\),然后把答案加上方案数 \(\times{(z-1)}^{\lvert E\rvert}\)

  记这 \(E\) 条边形成了 \(m\) 个连通块,这些连通块的大小为 \(a_1,a_2,\ldots,a_m\),那么贡献就是

\[\begin{align} &{(z-1)}^{n-m}\sum_{\sum_{i=1}^md_i=2m-2}(m-2)!\prod_{i=1}^m\frac{a_i^{d_i}}{(d_i-1)!}\\ =&{(z-1)}^{n-m}n^{m-2}\prod_{i=1}^ma_i\\ \end{align} \]

  \(\prod_{i=1}^ma_i\) 可以看成是每个连通块内选一个点的方案数。这样就可以DP了。

  时间复杂度:\(O(n)\)

op=2

  枚举两棵树的公共边个数:

\[\begin{align} s_n&=\sum_{i=1}^{n}{(z-1)}^{n-i}\sum_{\sum_{j=1}^ia_j=n}\frac{n!}{i!}(\prod_{j=1}^i\frac{a_j^{a_j-2}}{a_j!})(n^{i-2}\prod_{j=1}^ia_j)^2\\ &=\sum_{i=1}^{n}{(z-1)}^{n-i}\frac{n!n^{2i-4}}{i!}\sum_{\sum_{j=1}^ia_j=n}\prod_{j=1}^i\frac{a_j^{a_j}}{a_j!}\\ &=\sum_{i=1}^{n}{(z-1)}^{n-i}n^{2i-4}\sum_{\sum_{j=1}^ia_j=n}\prod_{j=1}^i\binom{(\sum_{k=1}^ja_k)-1}{a_j-1}{}a_j^{a_j}\\ \end{align} \]

  记 \(f_l=\sum_{i=1}^{l}{(z-1)}^{-i}n^{2i}\sum_{\sum_{j=1}^ia_j=l}\prod_{j=1}^i\binom{(\sum_{k=1}^ja_k)-1}{a_j-1}{}a_j^{a_j}\)

  转移时枚举最后一块的大小,有:

\[f_i=\begin{cases} 1&,i=0\\ \sum_{j=1}^i\frac{(i-1)!n^2j^jf_{i-j}}{(i-j)!(j-1)!(z-1)}&,i>0 \end{cases} \]

  直接DP是 \(O(n^2)\) 的。

  记 \(g_i=\sum_{i\geq 1}\frac{n^2i^i}{(i-1)!(z-1)}\)\(F(x)\)\(f\) 的 EGF,\(G(x)\)\(g\) 的 OGF,那么

\[\begin{align} xF'(x)&=F(x)G(x)\\ \frac{F'(x)}{F(x)}&=\frac{G(x)}{x}\\ \ln F(x)&=\int \frac{G(x)}{x}\\ F(x)&=e^{\int \frac{G(x)}{x}} \end{align} \]

  直接多项式 exp 就好了。

  答案为 \((z-1)^nn^{-4}f_n\)

  时间复杂度:\(O(n\log n)\)

代码

const ll p=998244353;
ll fp(ll a,ll b)
{
	ll s=1;
	for(;b;b>>=1,a=a*a%p)
		if(b&1)
			s=s*a%p;
	return s;
}
const int N=100010;
int n,op;
ll z,_z;
ll ans;
namespace solve0
{
	map<int,int> a[N];
	void solve()
	{
		if(_z==1)
		{
			ans=1;
			return;
		}
		int x,y;
		for(int i=1;i<n;i++)
		{
			io::get(x);
			io::get(y);
			if(x>y)
				swap(x,y);
			a[x][y]++;
		}
		ans=1;
		for(int i=1;i<n;i++)
		{
			io::get(x);
			io::get(y);
			if(x>y)
				swap(x,y);
			if(a[x].count(y))
				ans=ans*z%p;
		}
	}
}
namespace solve1
{
	vector<int> g[N];
	ll f[N][2];
	void dfs(int x,int fa)
	{
		f[x][0]=f[x][1]=1;
		for(auto v:g[x])
			if(v!=fa)
			{
				dfs(v,x);
				ll s0=(f[x][0]*f[v][0]%p*z+f[x][0]*f[v][1]%p*n)%p;
				ll s1=(f[x][0]*f[v][1]%p*z+f[x][1]*f[v][0]%p*z+f[x][1]*f[v][1]%p*n)%p;
				f[x][0]=s0;
				f[x][1]=s1;
			}
	}
	void solve()
	{
		if(_z==1)
		{
			ans=fp(n,n-2);
			return;
		}
		int x,y;
		for(int i=1;i<n;i++)
		{
			io::get(x);
			io::get(y);
			g[x].push_back(y);
			g[y].push_back(x);
		}
		z--;
		dfs(1,0);
		ans=f[1][1]*fp(n,p-2)%p;
	}
}
namespace solve2
{
	const int N=270000;
	namespace ntt
	{
		const int W=262144;
		ll w[N];
		int rev[N];
		void init()
		{
			w[0]=1;
			ll s=fp(3,(p-1)/W);
			for(int i=1;i<W/2;i++)
				w[i]=w[i-1]*s%p;
		}
		void ntt(ll *a,int n,int t)
		{
			for(int i=1;i<n;i++)
			{
				rev[i]=(rev[i>>1]>>1)|(i&1?n>>1:0);
				if(rev[i]>i)
					swap(a[i],a[rev[i]]);
			}
			for(int i=2;i<=n;i<<=1)
				for(int j=0;j<n;j+=i)
					for(int k=0;k<i/2;k++)
					{
						ll u=a[j+k];
						ll v=a[j+k+i/2]*w[W/i*k];
						a[j+k]=(u+v)%p;
						a[j+k+i/2]=(u-v)%p;
					}
			if(t==-1)
			{
				reverse(a+1,a+n);
				ll inv=fp(n,p-2);
				for(int i=0;i<n;i++)
					a[i]=a[i]*inv%p;
			}
		}
		void mul(ll *a,ll *b,ll *c,int n,int m,int l)
		{
			static ll a1[N],a2[N];
			int k=1;
			while(k<=n+m)
				k<<=1;
			memset(a1,0,sizeof(a1[0])*k);
			memset(a2,0,sizeof(a2[0])*k);
			memcpy(a1,a,sizeof(a1[0])*(n+1));
			memcpy(a2,b,sizeof(a2[0])*(m+1));
			ntt::ntt(a1,k,1);
			ntt::ntt(a2,k,1);
			for(int i=0;i<k;i++)
				a1[i]=a1[i]*a2[i]%p;
			ntt::ntt(a1,k,-1);
			memcpy(c,a1,sizeof(a1[0])*(l+1));
		}
		void inv(ll *a,ll *b,int n)
		{
			if(n==1)
			{
				b[0]=fp(a[0],p-2);
				return;
			}
			inv(a,b,n>>1);
			static ll a1[N],a2[N];
			memset(a1,0,sizeof(a1[0])*(n<<1));
			memset(a2,0,sizeof(a2[0])*(n<<1));
			memcpy(a1,a,sizeof(a1[0])*n);
			memcpy(a2,b,sizeof(a2[0])*(n>>1));
			ntt(a1,n<<1,1);
			ntt(a2,n<<1,1);
			for(int i=0;i<n<<1;i++)
				a1[i]=a2[i]*(2-a1[i]*a2[i]%p)%p;
			ntt(a1,n<<1,-1);
			memcpy(b,a1,sizeof(a1[0])*n);
		}
		void ln(ll *a,ll *b,int n)
		{
			static ll a1[N],a2[N],a3[N];
			for(int i=1;i<n;i++)
				a1[i-1]=a[i]*i%p;
			a1[n-1]=0;
			inv(a,a2,n);
			mul(a1,a2,a3,n-1,n-1,n-1);
			for(int i=1;i<n;i++)
				b[i]=a3[i-1]*fp(i,p-2)%p;
			b[0]=0;
		}
		void exp(ll *a,ll *b,int n)
		{
			if(n==1)
			{
				b[0]=1;
				return;
			}
			exp(a,b,n>>1);
			static ll a1[N],a2[N],a3[N];
			memset(b+(n>>1),0,sizeof(b[0])*(n>>1));
			ln(b,a3,n);
			memset(a1,0,sizeof(a1[0])*n);
			memset(a2,0,sizeof(a2[0])*n);
			memcpy(a1,b,sizeof(a1[0])*(n>>1));
			for(int i=0;i<(n>>1);i++)
				a2[i]=a[(n>>1)+i]-a3[(n>>1)+i];
			ntt(a1,n,1);
			ntt(a2,n,1);
			for(int i=0;i<n;i++)
				a1[i]=a1[i]*a2[i]%p;
			ntt(a1,n,-1);
			memcpy(b+(n>>1),a1,sizeof(a1[0])*(n>>1));
		}
	}
	ll inv[N],fac[N],ifac[N];
	ll f[N],g[N],w[N];
	void solve()
	{
		if(_z==1)
		{
			ans=fp(n,n-2)*fp(n,n-2)%p;
			return;
		}
		z--;
		ntt::init();
		fac[0]=fac[1]=ifac[0]=ifac[1]=inv[1]=1;
		for(int i=2;i<=n;i++)
		{
			fac[i]=fac[i-1]*i%p;
			inv[i]=-p/i*inv[p%i]%p;
			ifac[i]=ifac[i-1]*inv[i]%p;
		}
		ll ifacz=fp(z,p-2);
		
//		f[0]=1;
//		for(int i=1;i<=n;i++)
//			w[i]=fp(i,i);
//		for(int i=1;i<=n;i++)
//			for(int j=1;j<=i;j++)
//				f[i]=(f[i]+f[i-j]*fac[i-1]%p*ifac[i-j]%p*ifac[j-1]%p*n%p*n%p*w[j]%p*ifacz)%p;


		for(int i=1;i<=n;i++)
			g[i]=fp(i,i)*n%p*n%p*ifac[i-1]%p*ifacz%p*inv[i]%p;
		int k=1;
		while(k<=n)
			k<<=1;
		ntt::exp(g,f,k);
		ans=f[n]*fac[n]%p*fp(z,n)%p*fp(n,p-1-4)%p;
	}
}
int main()
{
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
	io::get(n);
	io::get(_z);
	io::get(op);
	z=fp(_z,p-2);
	if(op==0)
		solve0::solve();
	else if(op==1)
		solve1::solve();
	else
		solve2::solve();
	ans=ans*fp(_z,n)%p;
	ans=(ans%p+p)%p;
	io::put(ans);
	return 0;
}
posted @ 2019-02-03 22:11  ywwyww  阅读(1430)  评论(1编辑  收藏  举报