Loading

多项式乘法入门 FFT/NTT

首先要知道多项式是什么东西。多项式是形如 \(A(x)=a_0+a_1x+a_2x^2+a_3x^3+...+a_{N}x^N\) (即 \(A(x)=\sum_{i=0}^nx_i\))的整式。

两个多项式相加:\(C(x)=A(x)+B(x)=\sum\limits_{i=1}^N(a_i+b_i)x^i\)

两个多项式相减:\(C(x)=A(x)-B(x)=\sum\limits_{i=1}^N(a_i-b_i)x^i\)

两个次数分别为 \(n,m\) 的多项式相乘:\(C(x)=A(x)\times B(x)=\sum\limits_{i=0}^N\sum\limits_{j=0}^Ma_ib_jx^{i+j}\)。那么多项式 \(C(x)\) 的系数 \(c_i=\sum\limits_{j=0}^i a_jb_{i-j}\)

卷积:对于数列 \(A,B,C\),形如 \(c_i=\sum\limits_{j\oplus k=i}a_jb_k\) 的方式称为卷积(\(\oplus\) 为某一运算符)。如多项式乘法属于加卷积。

\(\text{FFT}\)

点值表示法

多项式 \(A(x)=\sum\limits_{i=1}^na_ix^i\) 的点值表示:带入任意一个 \(x\),即可得到对应的值。

\(A(x)=3x^2+5x+1\),带入 \(x=2\),可得 \(A(2)=3\times2^2+5\times 2+1=23\)

我们计算两个次数分别为 \(n,m\) 多项式相乘 \(A(x)B(x)\),可以先挑 \((n+m+1)\)\(x\) 计算点值,把对应位置相乘。如 \(A(x)=5x+3,B(x)=x+6\),选择 \(x=1,2,3\) 计算,\(A(1)=8,A(2)=13,A(3)=18,B(1)=7,B(2)=8,B(3)=9\),对应位置相乘可得 \(8\times 7=56,13\times 8=104,18\times 9=162\)

然后把这 \((n+m+1)\) 插值计算 \((n+m)\) 次多项式。按上面的例子,\(x=1,2,3\) 时值为 \(56,104,162\),可插值得到多项式 \(C(x)=5x^2+33x+18\)

接下来讲如何快速利用特殊的 \(x\) 来优化计算点值的过程。

复平面与单位根

在复平面上,横轴表示实部,纵轴表示虚部,以原点为起点的向量即可表示一个复数。

单位根:定义 \(\omega_n\) 为一个复数,使得 \((\omega_n)^n=1\)\(\omega_n\neq1\),比如 \(\omega_2=-1,\omega_4=-i\)。在复平面上,以原点为圆心作一个半径为 \(1\) 的圆,然后以数字 \(1\) 对应的向量为起点,逆时针旋转,把圆分成 \(n\) 等分,每等分划分的向量即为 \(1,\omega_n,\omega_n^2,\omega_n^3,...,\omega_n^{n-1}\)\(\omega_n^k=(\omega_n)^k\))。

计算:\(\omega_{n}^k=\cos k\frac{2\pi} n+i\sin k\frac{2\pi} n\)

几个性质:

  1. \(\omega_n^{n+k}=\omega_n^k\)

  2. \(\omega_n^{\frac n2}=-1\)

  3. \(\omega_n^{-k}=\omega_n^{n-k}\)

(自己画图就能明白)

快速傅里叶变换

我们希望带入 \(x=1,\omega_n,\omega_n^2,\omega_n^3,...,\omega_n^{n-1}\):。

对于多项式 \(A(x)\),我们把每一项系数按次数分成奇偶两类:

\[A(x)=a_0+a_1x+a_2x^2+a_3x^3+... \]

\[A_0(x)=a_0+a_2x+a_4x^2+... \]

\[A_1(x)=a_1+a_3x+a_5x^2+... \]

显然,\(A(x)=A_0(x^2)+xA_1(x^2)\)

带入 \(\omega_n^k(k<\frac n2)\) 得:

\[A(x)=A_0(\omega_n^{2k})+\omega_n^kA_1(\omega_n^{2k}) \]

带入 \(\omega_n^{k+\frac n2}(k<\frac n2)\) 得:

\[A(x)=A_0(\omega_n^{2k+n})+\omega_n^{k+\frac n 2}A_1(\omega_n^{2k+n}) \]

\[\space\space=A_0(\omega_n^{2k})-\omega_n^kA_1(\omega_n^{2k}) \]

然后我们发现带入两个值后,多项式 \(A_0\)\(A_1\) 里面的值都是 \(\omega_n^{2k}\),而只是右边的项的符号不同!!!!!!

注意 \(\omega_n^{2k}=\omega_{\frac n2}^k\),于是只需要处理 \(A_0,A_1\)\(x=1,\omega_{\frac n2},\omega_{\frac n2}^2,...,\omega_{\frac n2}^{\frac n2-1}\) 的点值即可!

这样可以直接分治,时间复杂度 \(O(n\log n)\)。当然,我们需要预处理出一个 \(>N+M\)\(N,M\) 为两个多项式次数)的 \(2\) 的幂 的数字 \(n\)

快速傅里叶逆变换

不要问我为什么,我们只需要把得到的点值对应位置相乘后,重新 \(\text{FFT}\) 一遍(过程一样),然后把 \(1\) 次项到 \(n-1\) 次项全部翻转一下即可(注意不是 \(0\) 次)。

模板

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e6+10;
ll n,m;
const double pi=acos(-1.0);
struct Complex
{
	double x,y;
	Complex (double xx=0,double yy=0) {x=xx; y=yy;}
	const Complex operator+(const Complex tmp) const
	{
		return (Complex){x+tmp.x,y+tmp.y};
	}
	const Complex operator-(const Complex tmp) const
	{
		return (Complex){x-tmp.x,y-tmp.y};
	}
	const Complex operator*(const Complex tmp) const
	{
		return (Complex){x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x};
	}
}a[maxn],b[maxn],c[maxn];
void fft(ll n,Complex *a)
{
	if(n==1) return;
	Complex a1[n>>1], a2[n>>1];
	ll m=n>>1;
	for(ll i=0;i<m;i++) a1[i]=a[2*i];
	for(ll i=0;i<m;i++) a2[i]=a[2*i+1];
	fft(m,a1);
	fft(m,a2);
	Complex W(cos(1.0*pi/m),sin(1.0*pi/m)), w(1,0);
	for(ll i=0;i<m;i++,w=w*W)
	{
		a[i]=a1[i]+w*a2[i];
		a[i+m]=a1[i]-w*a2[i];
	}
}
int main(){
	scanf("%lld%lld",&n,&m);
	for(ll i=0;i<=n;i++)
	{
		scanf("%lf",&a[i].x);
	}
	for(ll i=0;i<=m;i++)
	{
		scanf("%lf",&b[i].x);
	}
	ll s=n+m;
	ll p=1;
	while(p<=s) p<<=1;
	fft(p,a);
	fft(p,b);
	for(ll i=0;i<p;i++) c[i]=a[i]*b[i];
	fft(p,c);
	reverse(c+1,c+p);
	for(ll i=0;i<=s;i++) printf("%lld ",(ll)(c[i].x/p+0.5));
	return 0;
}

递归转递推(蝴蝶变换)

函数内开数组、递归…… 还有许多地方可以优化。

思考如何转递推。

我们尝试逐步模拟分类过程:

次数 \(0\) \(1\) \(2\) \(3\) \(4\) \(5\) \(6\) \(7\)
次数 \(0\) \(2\) \(4\) \(6\) \(1\) \(3\) \(5\) \(7\)
次数 \(0\) \(4\) \(2\) \(6\) \(1\) \(5\) \(3\) \(7\)

最后一行二进制分别为 \(000,100,010,110,001,101,011,111\)

反过来为 \(000,001,010,011,100,101,110,111\)

发现反过来是顺序的!

预处理 \(i\) 二进制反过来为 \(r_i\),按 \(r_i\) 排序。

然后直接递推即可。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e6+10;
ll n,m,r[maxn];
const double pi=acos(-1);
struct Complex
{
	double x,y;
	Complex(double xx=0,double yy=0)
	{
		x=xx; y=yy;
	}
	const Complex operator+(const Complex tmp) const
	{
		return (Complex){x+tmp.x,y+tmp.y};
	}
	const Complex operator-(const Complex tmp) const
	{
		return (Complex){x-tmp.x,y-tmp.y};
	}
	const Complex operator*(const Complex tmp) const
	{
		return (Complex){x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x};
	}
}a[maxn],b[maxn],c[maxn],tmp[maxn];
void fft(ll n,Complex *a)
{
	for(ll i=0;i<(1<<n);i++)
		if(i<r[i]) swap(a[i],a[r[i]]);
	for(ll i=0;i<n;i++)
	{
		Complex W(cos(pi/(1<<i)),sin(pi/(1<<i)));
		tmp[0]=(Complex){1,0};
		for(ll j=1;j<(1<<i);j++) tmp[j]=tmp[j-1]*W;
		for(ll j=0;j<(1<<n);j++)
			if(!(j&(1<<i)))
			{
				Complex a1=a[j], a2=a[j+(1<<i)], t=a2*tmp[j&((1<<i)-1)];
				a[j]=a1+t;
				a[j+(1<<i)]=a1-t;
			}
	}
}
int main()
{
	scanf("%lld%lld",&n,&m);
	for(ll i=0;i<=n;i++)
	{
		scanf("%lf",&a[i].x);
	}
	for(ll i=0;i<=m;i++)
	{
		scanf("%lf",&b[i].x);
	}
	ll k=n+m, p=1;
	while((1<<p)<=k) ++p;
	for(ll i=1;i<(1<<p);i++)
		r[i]=(r[i>>1]>>1)|((i&1)<<p-1);
	fft(p,a);
	fft(p,b);
	for(ll i=0;i<(1<<p);i++) c[i]=a[i]*b[i];
	fft(p,c);
	reverse(c+1,c+(1<<p));
	for(ll i=0;i<=k;i++) printf("%lld ",(ll)(c[i].x/(1<<p)+0.5));
	return 0;
}

\(\text{NTT}\)

\(\text{FFT}\) 能够在 \(O(n\log n)\) 的时间内完成对两个多项式相乘,但局限性很大。首先大量的单位根会影响精度,其次不能取模。我们需要一种更高效的方法。

原根

根据欧拉定理,如果 \(a,p\) 互质,那么 \(a^{\varphi(p)}\equiv 1\pmod p\)

原根:对于互质的 \(a,p\),如果不存在 \(m\),满足 \(m<\varphi(p)\),且 \(a^m\equiv 1\pmod p\),那么称 \(a\)\(p\) 的一个原根。

我们把 \(p\) 写成 \(2^xb+1\) 的形式,若 \(n\) 是一个以 \(2\) 为底的幂(\(n\le2^x\)),记 \(g_n=a^{\frac{p-1}n}\)\(a\) 是原根),那么满足以下性质:

A: \(g_n^n=(a^{\frac{p-1}n})^n=a^{p-1}\equiv 1\pmod p\)

B: \(g_n^{\frac n 2}=a^{\frac{p-1}2}\equiv -1\pmod p\)

我们容易得到类似于单位根的几个性质

  1. \(g_n^{n+k}=g_n^n\times g_n^k\equiv1\times g_n^k\equiv g_n^k\pmod p\)

  2. \(g_n^{\frac n 2}\equiv -1\pmod p\)

  3. \(g_n^{-k}\equiv 1\times g_n^{-k}\equiv g_n^n\times g_n^{-k}\equiv g_n^{n-k}\pmod p\)

这和单位根不一模一样吗!!!只不过在模意义下而已。

模数限制

通常情况下,模数为 \(998244353=119\times 2^{23}+1\),此时原根可取 \(3\)\(10^9+7\) 基本做不了,因为 \(10^9+7=500000003\times 2^1+1\),而 \(1\) 太小。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e6+10, mod=998244353, g=3;
ll n,m,a[maxn],b[maxn],c[maxn],r[maxn];
ll power(ll a,ll b)
{
	ll ans=1;
	while(b)
	{
		if(b&1) ans=ans*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return ans;
}
void ntt(ll n,ll *a)
{
	for(ll i=0;i<(1<<n);i++)
		if(i<r[i]) swap(a[i],a[r[i]]);
	for(ll i=0;i<n;i++)
	{
		ll G=power(g,(mod-1)/(1<<i+1));
		for(ll j=0;j<(1<<n);j+=(1<<i+1))
		{
			for(ll k=0,g=1;k<(1<<i);k++,g=g*G%mod)
			{
				ll a1=a[j+k], a2=a[(1<<i)+j+k], t=a2*g%mod;
				a[j+k]=(a1+t)%mod;
				a[(1<<i)+j+k]=(a1-t+mod)%mod;
			}
		}
	}
}
int main()
{
	scanf("%lld%lld",&n,&m);
	for(ll i=0;i<=n;i++) scanf("%lld",a+i);
	for(ll i=0;i<=m;i++) scanf("%lld",b+i);
	ll s=n+m, p=0;
	while((1<<p)<=s) ++p;
	for(ll i=1;i<(1<<p);i++)
		r[i]=(r[i>>1]>>1)|((i&1)<<p-1);
	ntt(p,a);
	ntt(p,b);
	for(ll i=0;i<(1<<p);i++) c[i]=a[i]*b[i]%mod;
	ntt(p,c);
	reverse(c+1,c+(1<<p));
	ll inv=power(1<<p,mod-2);
	for(ll i=0;i<=s;i++) printf("%lld ",c[i]*inv%mod);
	return 0;
}

分治 \(\text{FFT}\)

问题形式

已知 \(f_0=1\),而 \(f_i=\sum\limits_{j=0}^{i-1}f_j\times g_{i-j}\)\(g\) 是给定的,求 \(f_{0...n-1}\)

分治求解

考虑 cdq 分治,把 \([0,n-1]\) 分成 \([0,mid],[mid+1,n-1]\) 两部分。计算左边对右边的贡献,我们其实可以直接把 \([0,mid]\) 的多项式和 \(g\) 相乘,加到右边即可。

时间复杂度 \(O(n\log^2n)\)

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e5+10, mod=998244353, g=3;
ll n,a[maxn],f[maxn],rev[maxn],w1[maxn],w2[maxn],w[maxn];
ll power(ll a,ll b)
{
	ll ans=1;
	while(b)
	{
		if(b&1) ans=ans*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return ans;
}
void ntt(ll n,ll *a)
{
	for(ll i=0;i<(1<<n);i++)
	{
		rev[i]=(rev[i>>1]>>1)|((i&1)<<n-1);
		if(i<rev[i]) swap(a[i],a[rev[i]]);
	}
	for(ll i=0;i<n;i++)
	{
		ll G=power(g,(mod-1)/(1<<i+1));
		for(ll j=0;j<(1<<n);j+=(1<<i+1))
		{
			for(ll k=0,g0=1;k<(1<<i);k++,g0=g0*G%mod)
			{
				ll a1=a[j+k], a2=a[(1<<i)+j+k], t=a2*g0%mod;
				a[j+k]=(a1+t)%mod;
				a[(1<<i)+j+k]=(a1-t+mod)%mod;
			}
		}
	}
}
void cdq(ll l,ll r)
{
	if(l==r) return;
	ll mid=l+r>>1;
	cdq(l,mid);
	for(ll i=0;i<=r-l;i++) w1[i]=a[i];
	for(ll i=0;i<=mid-l;i++) w2[i]=f[l+i];
	ll s=r-l+mid-l, p=0;
	while((1<<p)<=s) ++p;
	ntt(p,w1);
	ntt(p,w2);
	for(ll i=0;i<(1<<p);i++) w[i]=w1[i]*w2[i]%mod;
	ntt(p,w);
	reverse(w+1,w+(1<<p));
	ll inv=power(1<<p,mod-2);
	for(ll i=mid+1;i<=r;i++)
	{
		f[i]=(f[i]+w[i-1-l]*inv)%mod;
	}
	for(ll i=0;i<(1<<p);i++) w1[i]=w2[i]=w[i]=0; 
	cdq(mid+1,r);
}
int main()
{
	scanf("%lld",&n);
	for(ll i=0;i<n-1;i++) scanf("%lld",a+i);
	f[0]=1;
	cdq(0,n-1);
	for(ll i=0;i<n;i++) printf("%lld ",f[i]);
	return 0;
}
posted @ 2023-08-06 19:24  Lgx_Q  阅读(22)  评论(0编辑  收藏  举报