半在线卷积

分治fft,是一种基于分治的、可以在\(O(n\log^2n)\)的复杂度内求解一类形如卷积形式的递推式的每一项的算法。具体的来举个例子。

【模板】分治fft

题意:给定数组\(g\)的前\(n\)项(从\(1\)开始),求\(f\)的前\(n\)项。其中\(f\)满足:

\[f_i=\sum_{j=1}^if_{i-j}g_j,f_0=1 \]

答案对\(998244353\)取模。

(当然你可以多项式求逆)

首先看到这个长得像卷积一样的递推式很自然地想到多项式科技。联想我们曾经cdq分治的套路,cdq优化dp的时候也是要先递归左边,然后处理左边对右边的影响,最后递归右边。这个也一样。我们分治处理左边和右边,合并的时候做一次fft计算左边对右边的贡献。

另外由于这题有模数所以写了ntt。fft乘爆了不想调所以不写了。

分治fft分治ntt

void calc(int a[],int b[]){
    ntt(a,wl,1);ntt(b,wl,1);
    for(int i=0;i<wl;i++)a[i]=1ll*a[i]*b[i]%mod;
    ntt(a,wl,-1);
}
void solve(int l,int r){
    if(l>=r)return;
    int mid=(l+r)>>1;
    solve(l,mid);
    int len=r-l+1;wl=1;
    get(len-2);//处理长度和反转结果
    for(int i=0;i<wl;i++)a[i]=b[i]=0;
    for(int i=l;i<=mid;i++)a[i-l]=f[i];//f往前平移
    for(int i=1;i<=r-l;i++)b[i-1]=g[i];//把g的前若干项拉过来 因为f在l时g下标最大 为r-l 所以循环到r-l就行
    calc(a,b);
    for(int i=mid+1;i<=r;i++)f[i]=(f[i]+a[i-l-1])%mod;
    //我们一开始把f和g的下标向左平移了l+1 所以统计答案也要平移l+1
    solve(mid+1,r);
}
int main(){
    scanf("%d",&n);f[0]=1;n--;
    for(int i=1;i<=n;i++)scanf("%d",&g[i]);
    solve(0,n);
    for(int i=0;i<=n;i++)printf("%d ",f[i]);
    return 0;
}

刚才那算暴力实现,现在提供一个比较好的实现。而且易于封装。

参考了 qwaszx 的实现。

首先我们半在线卷积一般是分二叉,复杂度 \(O(n\log^2n)\)。接下来我们试着加几个叉。设分为 \(B\) 叉,那么每次需要计算两两间的贡献。并不需要 \(B^2\)\(O(\dfrac nB\log\dfrac nB)\) 的卷积,我们只需要记录所有点值就可以做到 \(B\) 次。那么复杂度为

\[T(n)=BT(\frac nB)+O(nB+n\log\frac nB) \]

\(B=O(\log n)\),得到复杂度 \(T(n)=O(\dfrac{n\log^2n}{\log\log n})\)

然而跑得快需要一些细节。

  1. 叉数:\(B\) 通常取 \(8\)\(16\)
  2. 小范围暴力。
  3. 子问题大小取为 \(2\) 的幂,不浪费 DFT 长度。
  4. 可以在外面预处理每一层 \(g\) 的点值。具体的,假设子问题大小为 \(d\),第 \(i\) 个块到第 \(j\) 个块 \(g\) 的贡献区间是 \([(j-i-1)d,(j-i+1)d]\),可以预处理。
  5. 每次多项式乘法是 \(2d-1\) 次多项式和 \(d-1\) 次多项式乘,并只取 \([d,2d-1]\) 项系数,可以循环卷积。

这样整个半在线卷积就只需要留在外面一个暴力的接口。而且暴力也基本是一样的(因为都是长的像 \(f_n=c_n\sum_{i=0}^{n-1}f_ig_{n-i}\) 这种的),所以你甚至只需要留一个 Relax 函数做接口来处理 \(c_n\),然后在外边的函数部分处理别的东西。

一份板子。

void cdq(int l,int r,int dep,int f[],int g[],void brute(int f[],int g[],int l,int r)){
	if(r-l+1<=32){
		brute(f,g,l,r);return;
	}
	if(l>=r)return;
	static int tmp[300010];
	int d=1<<((dep-1)*3);
	for(int i=0;;i++){
		int L=l+i*d,R=min(r,L+d-1);
		if(i){
			for(int j=0;j<(d<<1);j++)tmp[j]=0;
			for(int j=0;j<i;j++){
				for(int k=0;k<(d<<1);k++){
					int x=1ll*f1[dep][j][k]*g1[dep][i-j][k]%mod;
					tmp[k]=add(tmp[k],x);
				}
			}
			DIT(tmp,d<<1);
			for(int i=L;i<=R;i++)f[i]=add(f[i],tmp[i-L+d]);
		}
		cdq(L,R,dep-1,f,g,brute);
		if(R==r)return;
		for(int j=0;j<(d<<1);j++)f1[dep][i][j]=0;
		for(int j=L;j<=R;j++)f1[dep][i][j-L]=f[j];
		DIF(f1[dep][i],d<<1);
	}
}
void solve(int f[],int g[],int n,void brute(int f[],int g[],int l,int r)){
	if(n<=128){
		brute(f,g,0,n-1);return;
	}
	int len=1,dep=0;
	while(len<n)len<<=3,dep++;len>>=3;
	int *now=buf;
	for(int i=1;i<=dep;i++){
		int d=1<<((i-1)*3),mn=min((n-1)/d,7);
		for(int j=1;j<=mn;j++){
			int l=(j-1)*d+1,r=min(n-1,(j+1)*d-1);
			f1[i][j-1]=now;now+=d<<1;
			g1[i][j]=now;now+=d<<1;
			for(int k=0;k<(1<<d);k++)g1[i][j][k]=0;
			for(int k=l;k<=r;k++)g1[i][j][k-l+1]=g[k];
			DIF(g1[i][j],d<<1);
		}
	}
	cdq(0,n-1,dep,f,g,brute);
}

好了现在问题就是暴力怎么写。一些 \(O(n^2)\) 做常见幂级数运算的方法:

  1. 求逆:

\[F(x)G(x)=1 \]

\[g_0f_n=-\sum_{i=0}^{n-1}f_ig_{n-i} \]

  1. \(\ln\):

\[xF'(x)G(x)=G'(x) \]

\[nf_n=ng_n-\sum_{i=0}^{n-1}if_ig_{n-i} \]

  1. \(\exp\):

\[xF'(x)=xG'(x)F(x) \]

\[nf_n=\sum_{i=0}^{n-1}(n-i)f_ig_{n-i} \]

posted @ 2022-09-03 19:51  gtm1514  阅读(198)  评论(0编辑  收藏  举报