多项式乘法入门 FFT/NTT

首先要知道多项式是什么东西。多项式是形如 A(x)=a0+a1x+a2x2+a3x3+...+aNxN (即 A(x)=i=0nxi)的整式。

两个多项式相加:C(x)=A(x)+B(x)=i=1N(ai+bi)xi

两个多项式相减:C(x)=A(x)B(x)=i=1N(aibi)xi

两个次数分别为 n,m 的多项式相乘:C(x)=A(x)×B(x)=i=0Nj=0Maibjxi+j。那么多项式 C(x) 的系数 ci=j=0iajbij

卷积:对于数列 A,B,C,形如 ci=jk=iajbk 的方式称为卷积( 为某一运算符)。如多项式乘法属于加卷积。

FFT

点值表示法#

多项式 A(x)=i=1naixi 的点值表示:带入任意一个 x,即可得到对应的值。

A(x)=3x2+5x+1,带入 x=2,可得 A(2)=3×22+5×2+1=23

我们计算两个次数分别为 n,m 多项式相乘 A(x)B(x),可以先挑 (n+m+1)x 计算点值,把对应位置相乘。如 A(x)=5x+3,B(x)=x+6,选择 x=1,2,3 计算,A(1)=8,A(2)=13,A(3)=18,B(1)=7,B(2)=8,B(3)=9,对应位置相乘可得 8×7=56,13×8=104,18×9=162

然后把这 (n+m+1) 插值计算 (n+m) 次多项式。按上面的例子,x=1,2,3 时值为 56,104,162,可插值得到多项式 C(x)=5x2+33x+18

接下来讲如何快速利用特殊的 x 来优化计算点值的过程。

复平面与单位根#

在复平面上,横轴表示实部,纵轴表示虚部,以原点为起点的向量即可表示一个复数。

单位根:定义 ωn 为一个复数,使得 (ωn)n=1ωn1,比如 ω2=1,ω4=i。在复平面上,以原点为圆心作一个半径为 1 的圆,然后以数字 1 对应的向量为起点,逆时针旋转,把圆分成 n 等分,每等分划分的向量即为 1,ωn,ωn2,ωn3,...,ωnn1ωnk=(ωn)k)。

计算:ωnk=cosk2πn+isink2πn

几个性质:

  1. ωnn+k=ωnk

  2. ωnn2=1

  3. ωnk=ωnnk

(自己画图就能明白)

快速傅里叶变换#

我们希望带入 x=1,ωn,ωn2,ωn3,...,ωnn1:。

对于多项式 A(x),我们把每一项系数按次数分成奇偶两类:

A(x)=a0+a1x+a2x2+a3x3+...

A0(x)=a0+a2x+a4x2+...

A1(x)=a1+a3x+a5x2+...

显然,A(x)=A0(x2)+xA1(x2)

带入 ωnk(k<n2) 得:

A(x)=A0(ωn2k)+ωnkA1(ωn2k)

带入 ωnk+n2(k<n2) 得:

A(x)=A0(ωn2k+n)+ωnk+n2A1(ωn2k+n)

  =A0(ωn2k)ωnkA1(ωn2k)

然后我们发现带入两个值后,多项式 A0A1 里面的值都是 ωn2k,而只是右边的项的符号不同!!!!!!

注意 ωn2k=ωn2k,于是只需要处理 A0,A1x=1,ωn2,ωn22,...,ωn2n21 的点值即可!

这样可以直接分治,时间复杂度 O(nlogn)。当然,我们需要预处理出一个 >N+MN,M 为两个多项式次数)的 2 的幂 的数字 n

快速傅里叶逆变换#

不要问我为什么,我们只需要把得到的点值对应位置相乘后,重新 FFT 一遍(过程一样),然后把 1 次项到 n1 次项全部翻转一下即可(注意不是 0 次)。

模板

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e6+10;
ll n,m;
const double pi=acos(-1.0);
struct Complex
{
	double x,y;
	Complex (double xx=0,double yy=0) {x=xx; y=yy;}
	const Complex operator+(const Complex tmp) const
	{
		return (Complex){x+tmp.x,y+tmp.y};
	}
	const Complex operator-(const Complex tmp) const
	{
		return (Complex){x-tmp.x,y-tmp.y};
	}
	const Complex operator*(const Complex tmp) const
	{
		return (Complex){x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x};
	}
}a[maxn],b[maxn],c[maxn];
void fft(ll n,Complex *a)
{
	if(n==1) return;
	Complex a1[n>>1], a2[n>>1];
	ll m=n>>1;
	for(ll i=0;i<m;i++) a1[i]=a[2*i];
	for(ll i=0;i<m;i++) a2[i]=a[2*i+1];
	fft(m,a1);
	fft(m,a2);
	Complex W(cos(1.0*pi/m),sin(1.0*pi/m)), w(1,0);
	for(ll i=0;i<m;i++,w=w*W)
	{
		a[i]=a1[i]+w*a2[i];
		a[i+m]=a1[i]-w*a2[i];
	}
}
int main(){
	scanf("%lld%lld",&n,&m);
	for(ll i=0;i<=n;i++)
	{
		scanf("%lf",&a[i].x);
	}
	for(ll i=0;i<=m;i++)
	{
		scanf("%lf",&b[i].x);
	}
	ll s=n+m;
	ll p=1;
	while(p<=s) p<<=1;
	fft(p,a);
	fft(p,b);
	for(ll i=0;i<p;i++) c[i]=a[i]*b[i];
	fft(p,c);
	reverse(c+1,c+p);
	for(ll i=0;i<=s;i++) printf("%lld ",(ll)(c[i].x/p+0.5));
	return 0;
}

递归转递推(蝴蝶变换)#

函数内开数组、递归…… 还有许多地方可以优化。

思考如何转递推。

我们尝试逐步模拟分类过程:

次数 0 1 2 3 4 5 6 7
次数 0 2 4 6 1 3 5 7
次数 0 4 2 6 1 5 3 7

最后一行二进制分别为 000,100,010,110,001,101,011,111

反过来为 000,001,010,011,100,101,110,111

发现反过来是顺序的!

预处理 i 二进制反过来为 ri,按 ri 排序。

然后直接递推即可。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e6+10;
ll n,m,r[maxn];
const double pi=acos(-1);
struct Complex
{
	double x,y;
	Complex(double xx=0,double yy=0)
	{
		x=xx; y=yy;
	}
	const Complex operator+(const Complex tmp) const
	{
		return (Complex){x+tmp.x,y+tmp.y};
	}
	const Complex operator-(const Complex tmp) const
	{
		return (Complex){x-tmp.x,y-tmp.y};
	}
	const Complex operator*(const Complex tmp) const
	{
		return (Complex){x*tmp.x-y*tmp.y,x*tmp.y+y*tmp.x};
	}
}a[maxn],b[maxn],c[maxn],tmp[maxn];
void fft(ll n,Complex *a)
{
	for(ll i=0;i<(1<<n);i++)
		if(i<r[i]) swap(a[i],a[r[i]]);
	for(ll i=0;i<n;i++)
	{
		Complex W(cos(pi/(1<<i)),sin(pi/(1<<i)));
		tmp[0]=(Complex){1,0};
		for(ll j=1;j<(1<<i);j++) tmp[j]=tmp[j-1]*W;
		for(ll j=0;j<(1<<n);j++)
			if(!(j&(1<<i)))
			{
				Complex a1=a[j], a2=a[j+(1<<i)], t=a2*tmp[j&((1<<i)-1)];
				a[j]=a1+t;
				a[j+(1<<i)]=a1-t;
			}
	}
}
int main()
{
	scanf("%lld%lld",&n,&m);
	for(ll i=0;i<=n;i++)
	{
		scanf("%lf",&a[i].x);
	}
	for(ll i=0;i<=m;i++)
	{
		scanf("%lf",&b[i].x);
	}
	ll k=n+m, p=1;
	while((1<<p)<=k) ++p;
	for(ll i=1;i<(1<<p);i++)
		r[i]=(r[i>>1]>>1)|((i&1)<<p-1);
	fft(p,a);
	fft(p,b);
	for(ll i=0;i<(1<<p);i++) c[i]=a[i]*b[i];
	fft(p,c);
	reverse(c+1,c+(1<<p));
	for(ll i=0;i<=k;i++) printf("%lld ",(ll)(c[i].x/(1<<p)+0.5));
	return 0;
}

NTT

FFT 能够在 O(nlogn) 的时间内完成对两个多项式相乘,但局限性很大。首先大量的单位根会影响精度,其次不能取模。我们需要一种更高效的方法。

原根#

根据欧拉定理,如果 a,p 互质,那么 aφ(p)1(modp)

原根:对于互质的 a,p,如果不存在 m,满足 m<φ(p),且 am1(modp),那么称 ap 的一个原根。

我们把 p 写成 2xb+1 的形式,若 n 是一个以 2 为底的幂(n2x),记 gn=ap1na 是原根),那么满足以下性质:

A: gnn=(ap1n)n=ap11(modp)

B: gnn2=ap121(modp)

我们容易得到类似于单位根的几个性质

  1. gnn+k=gnn×gnk1×gnkgnk(modp)

  2. gnn21(modp)

  3. gnk1×gnkgnn×gnkgnnk(modp)

这和单位根不一模一样吗!!!只不过在模意义下而已。

模数限制#

通常情况下,模数为 998244353=119×223+1,此时原根可取 3109+7 基本做不了,因为 109+7=500000003×21+1,而 1 太小。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e6+10, mod=998244353, g=3;
ll n,m,a[maxn],b[maxn],c[maxn],r[maxn];
ll power(ll a,ll b)
{
	ll ans=1;
	while(b)
	{
		if(b&1) ans=ans*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return ans;
}
void ntt(ll n,ll *a)
{
	for(ll i=0;i<(1<<n);i++)
		if(i<r[i]) swap(a[i],a[r[i]]);
	for(ll i=0;i<n;i++)
	{
		ll G=power(g,(mod-1)/(1<<i+1));
		for(ll j=0;j<(1<<n);j+=(1<<i+1))
		{
			for(ll k=0,g=1;k<(1<<i);k++,g=g*G%mod)
			{
				ll a1=a[j+k], a2=a[(1<<i)+j+k], t=a2*g%mod;
				a[j+k]=(a1+t)%mod;
				a[(1<<i)+j+k]=(a1-t+mod)%mod;
			}
		}
	}
}
int main()
{
	scanf("%lld%lld",&n,&m);
	for(ll i=0;i<=n;i++) scanf("%lld",a+i);
	for(ll i=0;i<=m;i++) scanf("%lld",b+i);
	ll s=n+m, p=0;
	while((1<<p)<=s) ++p;
	for(ll i=1;i<(1<<p);i++)
		r[i]=(r[i>>1]>>1)|((i&1)<<p-1);
	ntt(p,a);
	ntt(p,b);
	for(ll i=0;i<(1<<p);i++) c[i]=a[i]*b[i]%mod;
	ntt(p,c);
	reverse(c+1,c+(1<<p));
	ll inv=power(1<<p,mod-2);
	for(ll i=0;i<=s;i++) printf("%lld ",c[i]*inv%mod);
	return 0;
}

分治 FFT

问题形式#

已知 f0=1,而 fi=j=0i1fj×gijg 是给定的,求 f0...n1

分治求解#

考虑 cdq 分治,把 [0,n1] 分成 [0,mid],[mid+1,n1] 两部分。计算左边对右边的贡献,我们其实可以直接把 [0,mid] 的多项式和 g 相乘,加到右边即可。

时间复杂度 O(nlog2n)

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll maxn=4e5+10, mod=998244353, g=3;
ll n,a[maxn],f[maxn],rev[maxn],w1[maxn],w2[maxn],w[maxn];
ll power(ll a,ll b)
{
	ll ans=1;
	while(b)
	{
		if(b&1) ans=ans*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return ans;
}
void ntt(ll n,ll *a)
{
	for(ll i=0;i<(1<<n);i++)
	{
		rev[i]=(rev[i>>1]>>1)|((i&1)<<n-1);
		if(i<rev[i]) swap(a[i],a[rev[i]]);
	}
	for(ll i=0;i<n;i++)
	{
		ll G=power(g,(mod-1)/(1<<i+1));
		for(ll j=0;j<(1<<n);j+=(1<<i+1))
		{
			for(ll k=0,g0=1;k<(1<<i);k++,g0=g0*G%mod)
			{
				ll a1=a[j+k], a2=a[(1<<i)+j+k], t=a2*g0%mod;
				a[j+k]=(a1+t)%mod;
				a[(1<<i)+j+k]=(a1-t+mod)%mod;
			}
		}
	}
}
void cdq(ll l,ll r)
{
	if(l==r) return;
	ll mid=l+r>>1;
	cdq(l,mid);
	for(ll i=0;i<=r-l;i++) w1[i]=a[i];
	for(ll i=0;i<=mid-l;i++) w2[i]=f[l+i];
	ll s=r-l+mid-l, p=0;
	while((1<<p)<=s) ++p;
	ntt(p,w1);
	ntt(p,w2);
	for(ll i=0;i<(1<<p);i++) w[i]=w1[i]*w2[i]%mod;
	ntt(p,w);
	reverse(w+1,w+(1<<p));
	ll inv=power(1<<p,mod-2);
	for(ll i=mid+1;i<=r;i++)
	{
		f[i]=(f[i]+w[i-1-l]*inv)%mod;
	}
	for(ll i=0;i<(1<<p);i++) w1[i]=w2[i]=w[i]=0; 
	cdq(mid+1,r);
}
int main()
{
	scanf("%lld",&n);
	for(ll i=0;i<n-1;i++) scanf("%lld",a+i);
	f[0]=1;
	cdq(0,n-1);
	for(ll i=0;i<n;i++) printf("%lld ",f[i]);
	return 0;
}

出处:https://www.cnblogs.com/Sktn0089/p/17609803.html

版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。

posted @   Lgx_Q  阅读(42)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 分享4款.NET开源、免费、实用的商城系统
· 全程不用写代码,我用AI程序员写了一个飞机大战
· MongoDB 8.0这个新功能碉堡了,比商业数据库还牛
· 白话解读 Dapr 1.15:你的「微服务管家」又秀新绝活了
· 记一次.NET内存居高不下排查解决与启示
more_horiz
keyboard_arrow_up light_mode palette
选择主题
menu
点击右上角即可分享
微信分享提示