【BZOJ3625】【codeforces438E】小朋友和二叉树 生成函数+多项式求逆+多项式开根
首先,我们构造一个函数$G(x)$,若存在$k∈C$,则$[x^k]G(x)=1$。
不妨设$F(x)$为最终答案的生成函数,则$[x^n]F(x)$即为权值为$n$的神犇二叉树个数。
不难推导出,$[x^n]F(x)=\sum_{i=0}^{n}[x^i]G(x)\sum_{j=0}^{n-i}[x^j]F(j)\times [x^{n-j-i}]F(n-j-i)$。
(这个式子的意思就是说,不妨设当前根节点的权值为i,然后枚举左右两个子树的权值)
这个式子显然可以通过动规的方式去推,从而得出答案,优化后的时间复杂度是$O(n^2)$的,显然不行。
我们对式子进行化简,考虑到$[x^0]F(x)=1$,那么$F(x)=G(x)\times F^2(x)+1$。
通过移项,得到$G\times F^2-F+1=0$,是一个关于$F$的一元二次方程。
由于多项式$G(x)$是已知的,那么我们就可以通过求根公式解出$F(x)$。
套入求根公式,得到$F(x)=\frac{1±\sqrt{1-4G}}{2G}$。
考虑到$F(0)=1$,$G(0)=0$,那么$F(x)=\frac{1-\sqrt{1-4G}}{2G}$
分子分母同时乘上$1+\sqrt{1-4G}$,化简得到$F(x)=\frac{2}{1+\sqrt{1-4G}}$。
然后就是多项式开根+多项式求逆了。
#include<bits/stdc++.h> #define M (1<<18) #define L long long #define MOD 998244353 #define inv2 499122177 #define G 3 using namespace std; L pow_mod(L x,L k){ L ans=1; while(k){ if(k&1) ans=ans*x%MOD; x=x*x%MOD; k>>=1; } return ans; } void change(L a[],int n){ for(int i=0,j=0;i<n-1;i++){ if(i<j) swap(a[i],a[j]); int k=n>>1; while(j>=k) j-=k,k>>=1; j+=k; } } void NTT(L a[],int n,int on){ change(a,n); for(int h=2;h<=n;h<<=1){ L wn=pow_mod(G,(MOD-1)/h); for(int j=0;j<n;j+=h){ L w=1; for(int k=j;k<j+(h>>1);k++){ L u=a[k],t=w*a[k+(h>>1)]%MOD; a[k]=(u+t)%MOD; a[k+(h>>1)]=(u-t+MOD)%MOD; w=w*wn%MOD; } } } if(on==-1){ L inv=pow_mod(n,MOD-2); for(int i=0;i<n;i++) a[i]=a[i]*inv%MOD; reverse(a+1,a+n); } } void getinv(L a[],L b[],int n){ if(n==1){b[0]=pow_mod(a[0],MOD-2); return;} static L c[M],d[M]; memset(c,0,n<<4); memset(d,0,n<<4); getinv(a,c,n>>1); for(int i=0;i<n;i++) d[i]=a[i]; NTT(d,n<<1,1); NTT(c,n<<1,1); for(int i=0;i<(n<<1);i++) b[i]=(2*c[i]-d[i]*c[i]%MOD*c[i]%MOD+MOD)%MOD; NTT(b,n<<1,-1); for(int i=0;i<n;i++) b[n+i]=0; } void sqrt(L a[],L b[],int n){ if(n==1) return void(b[0]=1); sqrt(a,b,n>>1); static L invb[M],d[M]; memset(invb,0,M<<3); memset(d,0,M<<3); getinv(b,invb,n); for(int i=0;i<n;i++) d[i]=a[i]; NTT(b,n<<1,1); NTT(d,n<<1,1); NTT(invb,n<<1,1); for(int i=0;i<(n<<1);i++) b[i]=inv2*(b[i]+d[i]*invb[i]%MOD)%MOD; NTT(b,n<<1,-1); for(int i=0;i<n;i++) b[i+n]=0; } L a[M]={0},b[M]={0}; int main(){ int n,m; scanf("%d%d",&n,&m); int nn=1; while(nn<=m) nn<<=1; a[0]=1; for(int i=1;i<=n;i++){ int x; scanf("%d",&x); if(x<=m) a[x]=(a[x]-4+MOD)%MOD; } sqrt(a,b,nn); b[0]=(b[0]+1)%MOD; memset(a,0,nn<<3); getinv(b,a,nn); for(int i=1;i<=m;i++) printf("%lld\n",a[i]*2%MOD); }