UOJ#335. 【清华集训2017】生成树计数 多项式,FFT,下降幂,分治
原文链接www.cnblogs.com/zhouzhendong/p/UOJ335.html
前言
CLY大爷随手切这种题。
日常被CLY吊打系列。
题解
首先从 pruffer 编码的角度考虑这个问题。
pruffer 编码的长度为 $n-2$ ,如果点 $i$ 在 pruffer 编码中出现了 $d_i - 1$ 次,那么点 $i$ 的度数就是 $d_i$ ,对答案的贡献次数就是 $\binom {n-2}{d_i}a_i ^ {d_i}$ 。
于是自然想到用 EGF 做这个题。设
$$f_k(x) = \sum_{i=0}^{n-2} a_k ^ i (i+1) ^ m \frac{x^i}{i!}\\ g_k = \sum_{i=0}^{n-2} a_k ^ i (i+1) ^ {2m} \frac {x^i} {i!}$$
则答案就等于
$$\left(\prod_{i=1}^n a_i \right) \times \sum_{i=1}^n g_i(x) \prod_{j=1,j\neq i}^{n} f_j(x)$$
这个 EGF 的 $n-2$ 次项的系数。(注意是多项式系数乘以 $(n-2)!$)
假设
$$(x + 1) ^ m = \sum_{j=0}^m v_j i ^ {\underline{j}}$$
我们考虑对 $f_k(x)$ 操作一波,把:
$$\begin{eqnarray*}f_k(x) &=& \sum_{i=0}^{n-2} \frac{(a_kx)^i}{i!}\sum_{j=0}^m v_j i^{\underline{j}}\\&=&\sum_{j=0}^m v_j (a_kx) ^ j \sum_{i=0}^{n-2}\frac{(a_kx) ^ i}{i!}\\&=& \sum_{j=0}^m v_j(a_kx) ^ j e ^ {a_kx} \pmod {x^n-2} \end{eqnarray*}$$
设
$$(x + 1) ^2m = \sum_{j=0}^2m V_j i ^ {\underline{j}}$$
同理,我们对 $g_k(x)$ 也做类似的操作,可以得到
$$ans = \left(\prod_{i=1}^n a_i \right) e ^ {x\sum_{i=1}^n a_i} \sum_{i=1}^n \left(\sum_{j=0}^{2m} V_j (a_ix)^j \right)\prod_{k=1,k\neq i }^n \left(\sum_{t=0}^{m} v_j (a_kx)^t \right)$$
分治 FFT 即可。
时间复杂度 $O(nm\log ^2 n)$ 。
代码
#include <bits/stdc++.h> #define clr(x) memset(x,0,sizeof (x)) #define For(i,a,b) for (int i=a;i<=b;i++) #define Fod(i,b,a) for (int i=b;i>=a;i--) #define pb(x) push_back(x) #define mp(x,y) make_pair(x,y) #define fi first #define se second #define _SEED_ ('C'+'L'+'Y'+'A'+'K'+'I'+'O'+'I') #define outval(x) printf(#x" = %d\n",x) #define outvec(x) printf("vec "#x" = ");for (auto _v : x)printf("%d ",_v);puts("") #define outtag(x) puts("----------"#x"----------") #define outarr(a,L,R) printf(#a"[%d...%d] = ",L,R);\ For(_v2,L,R)printf("%d ",a[_v2]);puts(""); using namespace std; typedef long long LL; typedef unsigned long long ULL; typedef vector <int> vi; LL read(){ LL x=0,f=0; char ch=getchar(); while (!isdigit(ch)) f|=ch=='-',ch=getchar(); while (isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar(); return f?-x:x; } const int N=30005,M=65,Len=1<<16,mod=998244353; int Pow(int x,int y){ int ans=1; for (;y;y>>=1,x=(LL)x*x%mod) if (y&1) ans=(LL)ans*x%mod; return ans; } void Add(int &x,int y){ if ((x+=y)>=mod) x-=mod; } void Del(int &x,int y){ if ((x-=y)<0) x+=mod; } int del(int x){ return x<0?x+mod:x; } int add(int x){ return x>=mod?x-mod:x; } namespace poly{ int R[Len],w[Len]; void init(int n,int d){ For(i,1,n-1) R[i]=(R[i>>1]>>1)|((i&1)<<(d-1)); w[0]=1,w[1]=Pow(3,(mod-1)/n); For(i,2,n-1) w[i]=(LL)w[i-1]*w[1]%mod; } void FFT(int *a,int n,int flag){ if (flag<0) reverse(w+1,w+n); For(i,0,n-1) if (i<R[i]) swap(a[i],a[R[i]]); for (int t=n>>1,d=1;d<n;d<<=1,t>>=1) for (int i=0;i<n;i+=d<<1) for (int j=0;j<d;j++){ int tmp=(LL)w[t*j]*a[i+j+d]%mod; a[i+j+d]=del(a[i+j]-tmp); Add(a[i+j],tmp); } if (flag<0){ reverse(w+1,w+n); int inv=Pow(n,mod-2); For(i,0,n-1) a[i]=(LL)a[i]*inv%mod; } } } using poly::FFT; int n,m; int a[N]; int C[M][M],S[M][M],Fac[N],Inv[N]; int v1[M],v2[M]; void prework(){ int n=M-1; For(i,0,n) C[i][0]=C[i][i]=1; S[0][0]=1; For(i,1,n) For(j,1,i){ C[i][j]=add(C[i-1][j-1]+C[i-1][j]); Add(S[i][j]=S[i-1][j-1],(LL)S[i-1][j]*j%mod); } n=m; For(i,0,n) For(j,0,i) Add(v1[j],(LL)C[n][i]*S[i][j]%mod); n=m*2; For(i,0,n) For(j,0,i) Add(v2[j],(LL)C[n][i]*S[i][j]%mod); n=N-1; for (int i=Fac[0]=1;i<=n;i++) Fac[i]=(LL)Fac[i-1]*i%mod; Inv[n]=Pow(Fac[n],mod-2); Fod(i,n,1) Inv[i-1]=(LL)Inv[i]*i%mod; } int f[N*M],g[N*M]; int Hash(int i,int j){ return (i-1)*(m*2+1)+j; } int f1[Len],f2[Len],g1[Len],g2[Len],f3[Len],g3[Len]; int Solve(int L,int R){ if (L==R) return m*2; int mid=(L+R)>>1; int l1=Solve(L,mid),l2=Solve(mid+1,R); int p1=Hash(L,0),p2=Hash(mid+1,0); int s,d; for (s=1,d=0;s<l1+l2+1;s<<=1,d++); poly::init(s,d); For(i,0,s-1) f1[i]=f2[i]=g1[i]=g2[i]=0; For(i,0,l1) f1[i]=f[i+p1],g1[i]=g[i+p1]; For(i,0,l2) f2[i]=f[i+p2],g2[i]=g[i+p2]; FFT(f1,s,1),FFT(g1,s,1); FFT(f2,s,1),FFT(g2,s,1); For(i,0,s-1){ f3[i]=(LL)f1[i]*f2[i]%mod; g3[i]=((LL)f1[i]*g2[i]+(LL)g1[i]*f2[i])%mod; } FFT(f3,s,-1),FFT(g3,s,-1); int pR=Hash(R+1,0); For(i,p1,pR-1) f[i]=g[i]=0; For(i,0,s-1) if (i+p1<pR) f[i+p1]=f3[i],g[i+p1]=g3[i]; else break; int len=pR-1; while (!f[len]&&!g[len]&&len>p1) len--; while (len-p1>n) f[len]=g[len]=0,len--; return len-p1; } int Exp[Len]; int main(){ n=read(),m=read(); For(i,1,n) a[i]=read(); prework(); For(i,1,n){ int tmp=1; For(j,0,m){ f[Hash(i,j)]=(LL)v1[j]*tmp%mod; tmp=(LL)tmp*a[i]%mod; } tmp=1; For(j,0,m*2){ g[Hash(i,j)]=(LL)v2[j]*tmp%mod; tmp=(LL)tmp*a[i]%mod; } } int len=Solve(1,n),Sum=0; For(i,1,n) Add(Sum,a[i]); For(i,0,n) Exp[i]=(LL)Inv[i]*Pow(Sum,i)%mod; int L=1,d=0; for (;L<len+n+1;L<<=1,d++); poly::init(L,d); FFT(Exp,L,1),FFT(g,L,1); For(i,0,L-1) g[i]=(LL)g[i]*Exp[i]%mod; FFT(g,L,-1); int ans=(LL)g[n-2]*Fac[n-2]%mod; For(i,1,n) ans=(LL)ans*a[i]%mod; cout<<ans<<endl; return 0; }