[BZOJ3625] [Codeforces Round #250]小朋友和二叉树
Description
我们的小朋友很喜欢计算机科学,而且尤其喜欢二叉树。
考虑一个含有n个互异正整数的序列c[1],c[2],...,c[n]。如果一棵带点权的有根二叉树满足其所有顶点的权值都在集合{c[1],c[2],...,c[n]}中,我们的小朋友就会将其称作神犇的。并且他认为,一棵带点权的树的权值,是其所有顶点权值的总和。
给出一个整数m,你能对于任意的s(1<=s<=m)计算出权值为s的神犇二叉树的个数吗?请参照样例以更好的理解什么样的两棵二叉树会被视为不同的。
我们只需要知道答案关于998244353(7172^23+1,一个质数)取模后的值。
Input
第一行有2个整数 n,m(1<=n<=10^5; 1<=m<=10^5)。
第二行有n个用空格隔开的互异的整数 c[1],c[2],...,c[n](1<=c[i]<=10^5)。
Output
输出m行,每行有一个整数。第i行应当含有权值恰为i的神犇二叉树的总数。请输出答案关于998244353(=7172^23+1,一个质数)取模后的结果。
Sample Input
2 3
1 2
Sample Output
1
3
9
Solution
话说多项式的代码是真的难写...
思路其实比较简单,设\(f(n)\)表示花\(n\)的代价得到的二叉树的个数,\(g(n)\)表示有没有代价为\(n\)的点,只能为\(0,1\)。
那么很简单得到\(dp\)方程:
\[f(n)=\sum_{i=1}^ng(i)\sum_{j=1}^{n-i}f(j)f(n-i-j)
\]
这里是枚举根节点填什么,然后两边分别怎么填。
其中\(f\)的边界为\(f(0)=1\)。
利用直觉生成函数法可得,把\(f,g\)写成生成函数的形式可得:
\[F(x)=\sum_{n=0}^{\infty} f(n)x^n,G(x)=\sum_{n=0}^{\infty} g(n)x^n
\]
然后把\(dp\)方程写成卷积形式:
\[F(x)=F^2(x)G(x)+1
\]
\(+1\)是表示第\(0\)项。
解得:
\[F(x)=\frac{2}{1\pm\sqrt{1-4G(x)}}
\]
注意到\([0]G(x)=0\),取\(-\)分母为\(0\),所以取正号:
\[F(x)=\frac{2}{1+\sqrt{1-4G(x)}}
\]
然后就是多项式求逆和开根的板子了。
注意求逆和开根不要用同样的数组,不然各种冲突...我调了好久...
#include<bits/stdc++.h>
using namespace std;
void read(int &x) {
x=0;int f=1;char ch=getchar();
for(;!isdigit(ch);ch=getchar()) if(ch=='-') f=-f;
for(;isdigit(ch);ch=getchar()) x=x*10+ch-'0';x*=f;
}
void print(int x) {
if(x<0) putchar('-'),x=-x;
if(!x) return ;print(x/10),putchar(x%10+48);
}
void write(int x) {if(!x) putchar('0');else print(x);putchar('\n');}
#define lf double
#define ll long long
const int maxn = 2e5+10;
const int inf = 1e9;
const lf eps = 1e-8;
const int mod = 998244353;
const int inv2 = 499122177;
int f[maxn],g[maxn],n,m,mxn,bit,N,w[maxn],rw[maxn],s[maxn],t[maxn],pos[maxn];
int qpow(int a,int x) {
int res=1;
for(;x;x>>=1,a=1ll*a*a%mod) if(x&1) res=1ll*res*a%mod;
return res;
}
void prepare() {
w[0]=1;w[1]=qpow(3,(mod-1)/mxn);
for(int i=2;i<=mxn;i++) w[i]=1ll*w[i-1]*w[1]%mod;
rw[0]=1,rw[1]=qpow(qpow(3,mod-2),(mod-1)/mxn);
for(int i=2;i<=mxn;i++) rw[i]=1ll*rw[i-1]*rw[1]%mod;
}
void ntt(int *r,int op) {
for(int i=1;i<N;i++) if(pos[i]>i) swap(r[i],r[pos[i]]);
for(int i=1,d=mxn>>1;i<N;i<<=1,d>>=1)
for(int j=0;j<N;j+=i<<1)
for(int k=0;k<i;k++) {
int x=r[j+k],y=1ll*r[i+j+k]*(op==1?w:rw)[k*d]%mod;
r[j+k]=(x+y)%mod,r[i+j+k]=(x-y+mod)%mod;
}
if(op==-1) {
int inv=qpow(N,mod-2);
for(int i=0;i<N;i++) r[i]=1ll*r[i]*inv%mod;
}
}
int tmp1[maxn],tmp2[maxn],tmp3[maxn];
void get_pos(int len) {
for(bit=0,N=1;N<len;N<<=1,bit++);
for(int i=1;i<N;i++) pos[i]=pos[i>>1]>>1|((i&1)<<(bit-1));
}
void poly_inv(int *r,int *b,int len) {
if(len==1) return b[0]=qpow(r[0],mod-2),void();
poly_inv(r,b,len>>1);
for(int i=0;i<len;i++) tmp1[i]=b[i],tmp2[i]=r[i];
get_pos(len<<1);
ntt(tmp1,1),ntt(tmp2,1);
for(int i=0;i<N;i++) b[i]=((2ll*tmp1[i]%mod-1ll*tmp2[i]*tmp1[i]%mod*tmp1[i]%mod)%mod+mod)%mod;
ntt(b,-1);
for(int i=len;i<N;i++) b[i]=0;
for(int i=0;i<len<<1;i++) tmp1[i]=tmp2[i]=0;
}
void poly_sqrt(int *r,int *b,int len) {
if(len==1) return b[0]=r[0],void();
poly_sqrt(r,b,len>>1);
poly_inv(b,tmp3,len);
get_pos(len<<1);
for(int i=0;i<len;i++) tmp2[i]=r[i];
ntt(tmp2,1),ntt(tmp3,1);
for(int i=0;i<N;i++) tmp3[i]=1ll*tmp3[i]*tmp2[i]%mod;
ntt(tmp3,-1);
for(int i=0;i<len;i++) b[i]=1ll*inv2*(b[i]+tmp3[i])%mod;
for(int i=0;i<len<<1;i++) tmp3[i]=tmp2[i]=0;
}
int main() {
read(n),read(m);
for(int i=1,x;i<=n;i++) read(x),x<=m?g[x]=1:0;
for(mxn=1;mxn<=m<<1;mxn<<=1);
prepare();
for(int i=1;i<=m;i++) g[i]=(mod-4*g[i])%mod;
g[0]=(g[0]+1)%mod;
poly_sqrt(g,s,mxn>>1);s[0]=(s[0]+1)%mod;
poly_inv(s,t,mxn>>1);
for(int i=1;i<=m;i++) write((2ll*t[i]%mod+mod)%mod);
return 0;
}