分治 fft 的一种 nlogn 做法

问题是给定 \(g_{1...n}\), 求 \(f_{0...n}\), 其中 \(f_0=1,f_i=\sum\limits_{j<i}f_jg_{i-j}\).
考虑分治 .
现在要计算 \(f_{0...r}\) , 设 \(mid=\lfloor\frac r2\rfloor\).
假设我们已经计算出 \(f_{0...mid}\).
那么我们先计算 \(f_{0...mid}\)\(f_{mid+1...r}\) 的贡献 , 这里直接乘 \(g\) 即可 .
然后需要计算 \(f_{mid+1...r}\) 对自己的贡献 , 然后发现 \(f_{mid+1...r}\) 对自己贡献的系数就是 \(f_{0...mid}\) . 证明考虑展开 \(f_i\) 表达式 .
那么就让 \(f_{mid+1...r}\) 乘上 \(f_{0...mid}\) 贡献到 \(f_{mid+1...r}\) 即可 .
时间复杂度为 \(T(n)=T(\frac n2)+n\log n=O(n\log n)\)

LuoguP4721code
#include<bits/stdc++.h>
typedef long long ll;
using namespace std;
int read()
{
	int ret=0;bool f=0;char c=getchar();
	while(c>'9'||c<'0')f|=(c=='-'),c=getchar();
	while(c>='0'&&c<='9')ret=(ret<<3)+(ret<<1)+(c^48),c=getchar();
	return f?-ret:ret;
}
const int mod=998244353;
int qpow(int a,int b){int ret=1;for(;b;b>>=1,a=(ll)a*a%mod)if(b&1)ret=(ll)ret*a%mod;return ret;}
int R[1<<21],W[1<<21];
int n;
struct poly
{
	vector<int>v;
	int&operator[](const int &i){return v[i];}
	int len(){return v.size();}
	void set(int l){v.resize(l);}
	void ntt(int L,int typ)
	{
		int n=1<<L;
		for(int i=0;i<n;i++)R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
		W[0]=1;W[1]=qpow(3,(mod-1)/n);if(typ==-1)W[1]=qpow(W[1],mod-2);
		for(int i=2;i<n;i++)W[i]=(ll)W[i-1]*W[1]%mod;
		set(n);
		for(int i=0;i<n;i++)if(R[i]>i)swap(v[R[i]],v[i]);
		for(int t=n>>1,d=1;d<n;d<<=1,t>>=1)
			for(int i=0;i<n;i+=(d<<1))
				for(int j=0;j<d;j++)
				{
					int tmp=(ll)W[t*j]*v[i+j+d]%mod;
					v[i+j+d]=(v[i+j]-tmp+mod)%mod;
					v[i+j]=(v[i+j]+tmp)%mod;
				}
		if(typ==-1){int inv=qpow(n,mod-2);for(int i=0;i<n;i++)v[i]=(ll)v[i]*inv%mod;}
	}
	void adjust(){while(len()>1&&v.back()==0)v.pop_back();}
	poly operator *(const poly &x)const
	{
		poly ret,tmp0=*this,tmp1=x;
		int L=ceil(log2(tmp0.len()+tmp1.len())),n=1<<L;
		tmp0.ntt(L,1);tmp1.ntt(L,1);
		ret.set(n);
		for(int i=0;i<n;i++)ret[i]=(ll)tmp0[i]*tmp1[i]%mod;
		ret.ntt(L,-1);ret.adjust();return ret;
	}
}f,g;
void solve(int r)
{
	if(r==0){f[r]=1;return;}
	int mid=r>>1;
	solve(mid);
	poly tmp0,tmp1;
	tmp0.set(mid+1);
	for(int i=0;i<=mid;i++)tmp0[i]=f[i];
	tmp1.set(r+1);
	for(int i=0;i<=r;i++)tmp1[i]=g[i];
	tmp0=tmp0*tmp1;
	for(int i=mid+1;i<=r;i++)f[i]=tmp0[i];
	tmp0.set(mid+1);
	for(int i=0;i<=mid;i++)tmp0[i]=f[i];
	tmp1.set(r+1);
	for(int i=0;i<=mid;i++)tmp1[i]=0;
	for(int i=mid+1;i<=r;i++)tmp1[i]=f[i];
	tmp0=tmp0*tmp1;
	for(int i=mid+1;i<=r;i++)f[i]=tmp0[i];
}
int main()
{
	n=read();
	g.set(n);
	for(int i=1;i<n;i++)g[i]=read();
	f.set(n);
	solve(n-1);
	for(int &i:f.v)printf("%d ",i);putchar('\n');
	return 0;
}
posted @ 2022-02-09 10:02  zero4338  阅读(246)  评论(11编辑  收藏  举报