【模板】分治 FFT

I.【模板】分治 FFT

作为多项式的第一题,这题还是挺好理解的。

首先,我们完全可以把\(n\)扩大到\(2\)的次幂,空余地方补上\(0\),并且答案不变。

然后,对于递推式\(f_i=\sum\limits_{j=1}^{i}f_{i-j}g_j\),我们如果再令\(g_0=0\)的话,显然这个\(j\)的下界是可以改成\(0\)的——虽然这会使式子中出现\(f_i\)本身,但是它的系数为\(0\),可以忽略。

于是我们现在就有递推式\(f_i=\sum\limits_{j=0}^{i}f_{i-j}g_j\)

介于题名为“分治FFT”,于是我们考虑用CDQ分治解决本题。

我们设当前分治区间为\([l,r)\),并且\(r-l\)是一个\(2\)的次幂,设为\(k\)。假如\(k=0\),即区间大小为\(1\)的话,显然应该直接返回;否则,我们考虑分治中点\(mid=l+2^{k-1}\)

我们可以先处理区间\([l,mid)\)的贡献。然后,考虑区间\([l,mid)\)中数对\([mid,r)\)的贡献。

我们有\(f_i=\sum\limits_{j=0}^{i}f_{i-j}g_j\)

现在,我们只保留式子中下标在\([l,mid)\)中的\(i-j\),以及下标在\([mid,r)\)中的\(i\),再改变一下枚举顺序,就有\(f_{i+j}=\sum\limits_{i\in[l,mid)}\sum\limits_{j\in[0,2^k)}f_i\times g_j\)。也因此,最终卷积出来的结果只有区间\([mid,r)\)中的\(f\)予以保留。

我们可以设一个\(a_i=f_{i+l},b_i=g_i\),且\(a\)的长度为\(2^{k-1}\)\(b\)的长度为\(2^k\)。则如果我们计算\(c=a\times b\)的话,\(f_{[mid,r)}\)所受到的贡献,就是\(c_{[2^{k-1},2^k)}\)的值。显然\(c\)可以通过NTT求出。

时间复杂度\(O(n\log^2n)\)

代码:

#include<bits/stdc++.h>
using namespace std;
const int mod=998244353;
const int G=3;
const int N=1<<20;
int n,f[N],g[N],rev[N],all,lim,LG,invlim;
int ksm(int x,int y){
	int rt=1;
	for(;y;x=(1ll*x*x)%mod,y>>=1)if(y&1)rt=(1ll*rt*x)%mod;
	return rt;
}
void NTT(int *a,int tp){
	for(int i=0;i<lim;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int md=1;md<lim;md<<=1){
		int rt=ksm(G,(mod-1)/(md<<1));
		if(tp==-1)rt=ksm(rt,mod-2);
		for(int stp=md<<1,pos=0;pos<lim;pos+=stp){
			int w=1;
			for(int i=0;i<md;i++,w=(1ll*w*rt)%mod){
				int x=a[pos+i],y=(1ll*w*a[pos+md+i])%mod;
				a[pos+i]=(x+y)%mod;
				a[pos+md+i]=(x-y+mod)%mod;
			}
		}
	}
	if(tp==-1)for(int i=0;i<lim;i++)a[i]=(1ll*a[i]*invlim)%mod;
}
int a[N],b[N],c[N];
void func(int *arr,int k){
	for(int i=0;i<lim;i++)a[i]=b[i]=0;
	for(int i=0;i<(1<<k);i++)a[i]=arr[i];
	for(int i=0;i<(2<<k);i++)b[i]=g[i];
	lim=(4<<k),LG=k+2,invlim=ksm(lim,mod-2);
	for(int i=0;i<lim;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(LG-1));
	NTT(a,1),NTT(b,1);
	for(int i=0;i<lim;i++)a[i]=1ll*a[i]*b[i]%mod;
	NTT(a,-1);
}
void CDQ(int *arr,int k){
	if(!k)return;
	CDQ(arr,k-1);
	func(arr,k-1);
	for(int i=0;i<(1<<(k-1));i++)(arr[(1<<(k-1))+i]+=a[(1<<(k-1))+i])%=mod;
	CDQ(arr+(1<<(k-1)),k-1);
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<n;i++)scanf("%d",&g[i]);
	f[0]=1;
	while((1<<all)<n)all++;
	CDQ(f,all);
	for(int i=0;i<n;i++)printf("%d ",f[i]);puts("");
	return 0;
}

posted @ 2021-04-01 19:36  Troverld  阅读(72)  评论(0编辑  收藏  举报