分治FFT

分治FFT

考虑\(F(i)\),显然因为\(F\)是个卷积的形式(虽然我们不知道其中的某一部分),因此有:

\[F(x)=\sum_{i+j=x} F(i)G(j) \]

因此考虑我们计算出了前边一段的\(F\)值,可以通过乘上\(G\)中的一部分让这个\(F\)整体右移,如\(F(1)-F(3)\)卷上\(G(1)-G(3)\)就成为了\(F(4)-F(6)\)中得的一部分。

因此考虑分治。

考虑我们已经计算出了一段\([l,mid]\)中的真实\(F\)值,我们给右边的部分加上这些的贡献。

那么很显然就是\(F[l,mid]\)卷上一个\(G[0,r-l]\)就得到了\(F[mid,r]\)的一部分。

那么我们每次分治计算左区间后,\(NTT\)计算出左边对右边的贡献,然后累加上去即可。

对比代码理解更好哦QAQ。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<vector>
#define N 500005
#define pb push_back
#define g 3
#define gi 332748118
#define mod 998244353 
#define int long long
using namespace std;
int read()
{
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+(ch^48);ch=getchar();}
	return x*f;
}
int n,rev[N];
vector<int>F,G,S,T;
int ksm(int a,int b)
{
	int res=1;
	while(b)
	{
		if(b&1)res*=a,res%=mod;
		a*=a;a%=mod;b>>=1;
	}
	return res%mod;
}
void NTT(vector<int>&a,int limit,int type)
{
	for(int i=0;i<limit;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int mid=1;mid<limit;mid<<=1)
	{
		int Wn=ksm(type==1?g:gi,(mod-1)/(mid<<1));
		for(int j=0;j<limit;j+=(mid<<1))
		{
			int w=1;
			for(int k=0;k<mid;k++,w=(w*Wn%mod)%mod)
			{
				int x=a[j+k]%mod,y=w*a[j+k+mid]%mod;
				a[j+k]=(x+y)%mod;
				a[j+k+mid]=(x-y+mod)%mod;
			}
		}
	}
	if(type==-1)
	{
		int INV=ksm(limit,mod-2);
		for(int i=0;i<limit;i++)a[i]=a[i]*INV%mod;
	}
}
int get_limit(int x)
{
	int limit=1;while(limit<=x)limit<<=1;
	for(int i=0;i<limit;i++)rev[i]=((rev[i>>1]>>1)|((i&1)?limit>>1:0));
	return limit;
}
vector<int> operator*(vector<int>&a,vector<int>&b)
{
	int len=a.size()+b.size()-1;
	int limit=get_limit(len);
	a.resize(limit);b.resize(limit);
	NTT(a,limit,1);NTT(b,limit,1);
	for(int i=0;i<limit;i++)a[i]=a[i]*b[i]%mod;
	NTT(a,limit,-1);a.resize(len);
	return a;
}
void solve(int l,int r)
{
	if(l==r)return;
	int mid=(l+r)>>1;
	solve(l,mid);
	S.clear();T.clear();
	for(int i=l;i<=mid;i++)S.pb(F[i]),T.pb(G[i-l]);
	for(int i=mid+1;i<=r;i++)S.pb(0),T.pb(G[i-l]);
	S=S*T;
	for(int i=mid+1;i<=r;i++)F[i]=(F[i]+S[i-l])%mod;
	solve(mid+1,r);
}
signed main()
{
	n=read();G.pb(0);F.pb(1);
	for(int i=1;i<n;i++)G.pb(read());
	for(int i=1;i<n;i++)F.pb(0);
	solve(0,n-1);
	for(int i=0;i<n;i++)printf("%d ",F[i]);
	return 0;
}

posted @ 2021-03-03 20:24  shao0320  阅读(133)  评论(0编辑  收藏  举报
****************************************** 页脚Html代码 ******************************************