[Codeforces438E][bzoj3625] 小朋友和二叉树 [多项式求逆+多项式开根]
题面
思路
首先,我们把这个输入的点的生成函数搞出来:
$C=\sum_{i=0}{lim}s_ixi$
其中$lim$为集合里面出现过的最大的数,$s_i$表示大小为$i$的数是否出现过
我们再设另外一个函数$F$,定义$F_k$表示总权值为$k$的二叉树个数
那么,一个二叉树显然可以通过两个子树(可以权值为0,也就是空子树)和一个节点构成
那么有如下求$F$的式子
$F_0=1$
$F_k=\sum_{i=0}^k s_i \sum_{j=0}^{k-i} F_j F_{k-i-j}$
显然这个式子是一个卷积的形式
我们把这个三个东西的卷积放到整个函数上来考虑(也就是把$F_kC_k$看做多项式的系数做卷积),可以得到:
$F=C\ast F\ast F +1$,这里的$\ast$是多项式乘法
由于多项式乘法和普通乘法一样满足各种定律之类的,所以我们可以解个一元二次方程
$C\ast F^2 - F + 1=0$
$F=\frac{1\pm \sqrt{1-4C}}{2C}$
先考虑上面是+的情况
此时${\lim_{x \to 0}}F(x)=+\infty$,舍去
再考虑-的情况
此时${\lim_{x \to 0}}F(x)=1$,OK
其实到这里就可以算了,一个多项式开方+多项式求逆就解决了
但是我们更进一步,再划一划这个式子,可以变成如下形式:
$F=\frac{2}{1+\sqrt{1-4C}}$
多项式开方+多项式求逆解决问题
Code:
这题超!级!卡!常!我这种人傻常数大的,只能参照爷稳稳的代码写一写了qwq
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cassert>
#include<cmath>
#include<ctime>
#define ll long long
#define MOD 998244353
using namespace std;
inline int read(){
int re=0,flag=1;char ch=getchar();
while(ch>'9'||ch<'0'){
if(ch=='-') flag=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9') re=(re<<1)+(re<<3)+ch-'0',ch=getchar();
return re*flag;
}
int n,m,K;
ll A[600010],B[600010],C[600010],x[600010],y[600010];
int g=3,ginv,inv2,r[600010],lim,cnt;
ll qpow(ll a,ll b){
ll re=1;
while(b){
if(b&1) re=re*a%MOD;
a=a*a%MOD;b>>=1;
}
return re;
}
ll add(ll a,ll b){
a+=b;
return (a>=MOD)?a-MOD:a;
}
ll dec(ll a,ll b){
a-=b;
return (a<0)?a+MOD:a;
}
void ntt(ll *a,ll type){
int i,j,k;ll x,y,w,wn,inv,mid;
for(i=0;i<lim;i++) if(i<r[i]) swap(a[i],a[r[i]]);
for(mid=1;mid<lim;mid<<=1){
wn=qpow((type==1)?g:ginv,(MOD-1)/(mid<<1));
for(j=0;j<lim;j+=(mid<<1)){
w=1;
for(k=0;k<mid;k++,w=w*wn%MOD){
x=a[j+k];y=a[j+k+mid]*w%MOD;
a[j+k]=add(x,y);
a[j+k+mid]=dec(x,y);
}
}
}
if(~type) return;
inv=qpow(lim,MOD-2);
for(i=0;i<lim;i++) a[i]=a[i]*inv%MOD;
}
void getinv(ll *a,ll *b,int len,int lena){
if(len==1){b[0]=qpow(a[0],MOD-2);return;}
assert(len==((len>>1)<<1));
int i,mid=len>>1;
getinv(a,b,mid,lena);
lim=1;cnt=0;r[0]=0;
while(lim<=(len<<1)) lim<<=1,cnt++;
for(i=0;i<lim;i++) r[i]=((r[i>>1]>>1)|((i&1)<<(cnt-1)));
for(i=0;i<len;i++) x[i]=a[i];
for(i=len;i<lim;i++) x[i]=0;
for(i=0;i<mid;i++) y[i]=b[i];
for(i=len;i<lim;i++) y[i]=0;
ntt(x,1);ntt(y,1);
for(i=0;i<lim;i++) x[i]=y[i]*(2ll-y[i]*x[i]%MOD+MOD)%MOD;
ntt(x,-1);
for(i=0;i<len;i++) b[i]=x[i];
}
void getrt(ll *a,ll *b,int len){
if(len==1){b[0]=1;return;}
int i,mid=len>>1;
getrt(a,b,mid);
memset(C,0,sizeof(C));getinv(b,C,len,len);
lim=1;cnt=0;r[0]=0;
while(lim<=(len<<1)) lim<<=1,cnt++;
for(i=0;i<lim;i++) r[i]=((r[i>>1]>>1)|((i&1)<<(cnt-1))),x[i]=a[i];
for(i=len;i<lim;i++) x[i]=C[i]=b[i]=0;
ntt(x,1);ntt(C,1);
for(i=0;i<lim;i++) x[i]=x[i]*C[i]%MOD;
ntt(x,-1);
for(i=0;i<len;i++) b[i]=add(b[i],x[i])*inv2%MOD;
}
void init(){
ginv=qpow(g,MOD-2);
inv2=qpow(2,MOD-2);
}
int main(){
init();
n=read();m=read();int i,t1;A[0]=K=1;
for(i=1;i<=n;i++) t1=read(),A[t1]=dec(A[t1],4);
while(K<=m) K<<=1;
getrt(A,B,K);
B[0]=add(B[0],1);
memset(C,0,sizeof(C));
getinv(B,C,K,K);
for(i=1;i<=m;i++) printf("%lld\n",(2*C[i]+MOD)%MOD);
}