多项式入门(FFT,NTT,MTT)

前言

又到了愉快的数学时间了!

由于蒟蒻作者数学特别烂,贴个友链然后直接上代码吧

本博客由某谷搬运过来,目的只是存板子。

\(update\ 2021.2.2\) 更新了部分代码,已经基本学懂,懒得更新讲解了。u1s1,到了高中后有了一定数学基础,就是比初中傻白甜的时候学得快。

\(update\ 2021.2.3\) 学懂了蝴蝶变换之后又更新了一波板子,不得不说,迭代版本真的快得离谱。

\(update\ 2021.7.27\) 折叠了代码,增加文章可读性。

FFT

友链

Aaplloo

OneInDark

练习

板题(UOJ)

板题(洛谷)

力(洛谷)

代码

板题代码

$Mine\ \text{(递归)}$
//12252024832524
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define TT template<typename T>
using namespace std; 

typedef long long LL;
const int MAXN = 1 << 21 | 5;   // array capacity; must exceed the padded transform length
const double PI = std::acos(-1.0);
int lena,lenb;                  // degrees of the two input polynomials

/// Reads the next integer from stdin.
/// Characters before the first digit are skipped; any '-' among them makes the
/// result negative. At end of input it stops and returns what was read
/// (0 if no digits) instead of spinning forever: c is an int, not a char,
/// so EOF stays distinguishable from every real character.
LL Read()
{
	LL x = 0,f = 1;
	int c = getchar();
	while(c != EOF && (c < '0' || c > '9')){if(c == '-') f = -1;c = getchar();}
	while(c >= '0' && c <= '9'){x = x * 10 + (c - '0');c = getchar();}
	return x * f;
}
/// Writes the decimal digits of a non-negative integer x (no sign, no separator).
template<typename T> void Put1(T x)
{
	if(x > 9) Put1(x / 10);
	putchar('0' + static_cast<int>(x % 10));
}
/// Writes integer x, then character c unless c is negative (default: nothing).
/// c is an int rather than a char: with a char parameter the default -1 turns
/// into 255 where plain char is unsigned, and a stray byte would be printed.
template<typename T> void Put(T x,int c = -1)
{
	if(x < 0) putchar('-'),x = -x;
	Put1(x);
	if(c >= 0) putchar(c);
}
template<typename T> T Max(T x,T y){return x > y ? x : y;}
template<typename T> T Min(T x,T y){return x < y ? x : y;}
template<typename T> T Abs(T x){return x < 0 ? -x : x;}

/// Minimal complex number: x = real part, y = imaginary part.
struct cp
{
	double x,y;
	// Default construction yields 0 + 0i instead of indeterminate members.
	// constexpr keeps the global arrays below constant-initialised (no startup loop).
	constexpr cp() : x(0),y(0) {}
	constexpr cp(double x1,double y1) : x(x1),y(y1) {}
	cp operator + (const cp &A) const {return cp(x+A.x,y+A.y);}
	cp operator - (const cp &A) const {return cp(x-A.x,y-A.y);}
	cp operator * (const cp &A) const {return cp(x*A.x-y*A.y,x*A.y+y*A.x);}
}a[MAXN],b[MAXN];   // coefficients (later spectra) of the two input polynomials

void FFT(int len,cp * a,int f)
{
	if(len == 1) return;
	cp a1[len >> 1],a2[len >> 1];
	for(int i = 0;i < len;i += 2) a1[i >> 1] = a[i],a2[i >> 1] = a[i+1];
	FFT(len>>1,a1,f);
	FFT(len>>1,a2,f);
	cp w = cp(cos(2*PI/len),f*sin(2*PI/len)),k = cp(1,0);
	len >>= 1;
	for(int i = 0;i < len;++ i,k = k * w)
	{
		a[i] = a1[i] + k * a2[i];
		a[i+len] = a1[i] - k * a2[i];
	}
}

int main()
{
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	// Input: both degrees, then the coefficients of each polynomial (stored in the real parts).
	lena = Read(); lenb = Read();
	for(int i = 0;i <= lena;++ i) a[i].x = Read();
	for(int i = 0;i <= lenb;++ i) b[i].x = Read();
	// Transform length: smallest power of two strictly greater than the product degree.
	int len = 1;
	while(len <= lena + lenb) len <<= 1;
	FFT(len,a,1);
	FFT(len,b,1);
	// Point-wise product over exactly the len transformed entries (0..len-1).
	for(int i = 0;i < len;++ i) a[i] = a[i] * b[i];
	FFT(len,a,-1);
	// The inverse above is unscaled: divide by len, then round to nearest
	// (+0.5 rounding is only correct for non-negative coefficients, assumed here).
	for(int i = 0;i <= lena+lenb;++ i) Put(static_cast<int>(a[i].x/len + 0.5),' ');
	return 0;
}
$Mine\ \text{(迭代)}$
//12252024832524
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define TT template<typename T>
using namespace std; 

typedef long long LL;
const int MAXN = 1 << 21 | 5;   // array capacity; must exceed the padded transform length
const double PI = std::acos(-1.0);
int lena,lenb,len = 1,l = -1;   // input degrees; transform length (grown in main); log2(len) - 1
int rev[MAXN];                  // bit-reversal permutation of [0, len)

/// Reads the next integer from stdin.
/// Characters before the first digit are skipped; any '-' among them makes the
/// result negative. At end of input it stops and returns what was read
/// (0 if no digits) instead of spinning forever: c is an int, not a char,
/// so EOF stays distinguishable from every real character.
LL Read()
{
	LL x = 0,f = 1;
	int c = getchar();
	while(c != EOF && (c < '0' || c > '9')){if(c == '-') f = -1;c = getchar();}
	while(c >= '0' && c <= '9'){x = x * 10 + (c - '0');c = getchar();}
	return x * f;
}
/// Writes the decimal digits of a non-negative integer x (no sign, no separator).
template<typename T> void Put1(T x)
{
	if(x > 9) Put1(x / 10);
	putchar('0' + static_cast<int>(x % 10));
}
/// Writes integer x, then character c unless c is negative (default: nothing).
/// c is an int rather than a char: with a char parameter the default -1 turns
/// into 255 where plain char is unsigned, and a stray byte would be printed.
template<typename T> void Put(T x,int c = -1)
{
	if(x < 0) putchar('-'),x = -x;
	Put1(x);
	if(c >= 0) putchar(c);
}
template<typename T> T Max(T x,T y){return x > y ? x : y;}
template<typename T> T Min(T x,T y){return x < y ? x : y;}
template<typename T> T Abs(T x){return x < 0 ? -x : x;}

/// Minimal complex number: x = real part, y = imaginary part.
struct cp
{
	double x,y;
	// Default construction yields 0 + 0i instead of indeterminate members.
	// constexpr keeps the global arrays below constant-initialised (no startup loop).
	constexpr cp() : x(0),y(0) {}
	constexpr cp(double x1,double y1) : x(x1),y(y1) {}
	cp operator + (const cp &A)const{return cp(x+A.x,y+A.y);}
	cp operator - (const cp &A)const{return cp(x-A.x,y-A.y);}
	cp operator * (const cp &A)const{return cp(x*A.x-y*A.y,x*A.y+y*A.x);}
}a[MAXN],b[MAXN];   // coefficients (later spectra) of the two input polynomials

/// Iterative radix-2 FFT over the global length `len`, in place on a[0..len).
/// Expects rev[] to hold the bit-reversal permutation for `len`.
/// opt = 1: forward transform; opt = -1: inverse transform including the
/// 1/len scaling, applied to both the real and the imaginary part so the
/// result is a true inverse (previously only .x was scaled).
void FFT(cp *a,int opt)
{
	// Bit-reversed order lets every butterfly level run in place.
	for(int i = 0;i < len;++ i) if(i < rev[i]) swap(a[i],a[rev[i]]);
	// mid = half-size of the blocks merged at this level.
	for(int mid = 1;mid < len;mid <<= 1)
	{
		const cp wn = cp(cos(PI/mid),opt*sin(PI/mid));   // (2*mid)-th root of unity, conjugated when opt = -1
		for(int start = 0,step = mid << 1;start < len;start += step)
		{
			cp w = cp(1,0);
			for(int k = 0;k < mid;++ k,w = w * wn)
			{
				const cp X = a[start+k],Y = w * a[start+mid+k];
				a[start+k] = X + Y;
				a[start+mid+k] = X - Y;
			}
		}
	}
	if(opt == -1) for(int i = 0;i < len;++ i) a[i].x /= len,a[i].y /= len;
}

int main()
{
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	// Input: both degrees, then each polynomial's coefficients (constant term
	// first), kept in the real parts of a[] and b[].
	lena = Read();
	lenb = Read();
	for(int i = 0;i <= lena;++ i) a[i].x = Read();
	for(int i = 0;i <= lenb;++ i) b[i].x = Read();
	const int target = lena + lenb;
	// Double len until it exceeds the product degree; l ends at log2(len) - 1.
	for(;len <= target;len <<= 1) ++l;
	// rev[i]: i with its bits reversed, derived from rev[i / 2].
	for(int i = 0;i < len;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l);
	FFT(a,1);
	FFT(b,1);
	for(int i = 0;i < len;++ i) a[i] = a[i] * b[i];
	FFT(a,-1);   // the inverse call already divides by len
	for(int i = 0;i <= target;++ i) Put((int)(a[i].x + 0.5),' ');
	return 0;
}
$French's\space code\ \text{(迭代)}$
#include<iostream>
#include<algorithm>
#include<cmath>
using namespace std;
// Complex numbers are represented as pair<double,double>: first = real part,
// second = imaginary part. The operators spell out .first/.second directly
// instead of `#define x first` / `#define y second`, which rewrote every
// later `x`/`y` token until the matching #undef.
// NOTE(review): a, b and r together reserve roughly 360 MB; 1 << 21 | 5 entries
// would suffice when n + m < 2^21 (the capacity the other programs here use
// for this problem) — confirm the intended constraints before shrinking.
const int maxn = 10000005;
const double pi = std::acos(-1.0);
int n,m;                              // degrees of the two polynomials
int limit = 1;                        // transform length (power of two), grown by perpare()
std::pair<double,double> a[maxn];
std::pair<double,double> b[maxn];
int l;                                // number of doublings of limit, i.e. log2(limit) after perpare()
int r[maxn];                          // bit-reversal permutation, filled by perpare()
/// Complex addition.
std::pair<double,double> operator + (std::pair<double,double> a,std::pair<double,double> b)
{
	return std::make_pair(a.first+b.first,a.second+b.second);
}
/// Complex subtraction.
std::pair<double,double> operator - (std::pair<double,double> a,std::pair<double,double> b)
{
	return std::make_pair(a.first-b.first,a.second-b.second);
}
/// Complex multiplication: (a.re + i*a.im) * (b.re + i*b.im).
std::pair<double,double> operator * (std::pair<double,double> a,std::pair<double,double> b)
{
	return std::make_pair(a.first*b.first-a.second*b.second,a.first*b.second+a.second*b.first);
}
/// In-place iterative FFT of a[0..limit) using the bit-reversal table r[].
/// t = 1: forward transform; t = -1: inverse transform without the 1/limit
/// factor (work() divides afterwards).
void fft(pair<double,double> *a,int t)
{
	for(int i=0;i<limit;i++)
		if(i<r[i])
			swap(a[i],a[r[i]]);
	// half = size of each already-transformed sub-block being merged.
	for(int half=1;half<limit;half<<=1)
	{
		const pair<double,double> step_root=make_pair(cos(pi/half),t*sin(pi/half));
		const int block=half<<1;
		for(int base=0;base<limit;base+=block)
		{
			pair<double,double> w=make_pair(1.0,0.0);
			for(int k=0;k<half;k++)
			{
				const pair<double,double> lo=a[base+k];
				const pair<double,double> hi=w*a[base+half+k];
				a[base+k]=lo+hi;
				a[base+half+k]=lo-hi;
				w=w*step_root;
			}
		}
	}
}
/// Multiplies the polynomials held in a[] and b[] via FFT and prints the
/// n+m+1 coefficients of the product, each followed by a space.
void work()
{
	fft(a,1);
	fft(b,1);
	// Point-wise product over exactly the limit transformed entries (0..limit-1).
	for(int i=0;i<limit;i++)
		a[i]=a[i]*b[i];
	fft(a,-1);
	// fft() leaves the inverse unscaled: divide by limit and round
	// (+0.5 rounding assumes non-negative coefficients).
	for(int i=0;i<=n+m;i++)
		cout<<(int)(a[i].first/limit+0.5)<<' ';
}
/// Chooses the transform length and builds the bit-reversal table.
/// Doubles limit until it exceeds n+m, incrementing l once per doubling
/// (so l == log2(limit) on the first call), then sets r[i] to i with its
/// l low bits reversed. The name is kept as spelled because main() calls it.
void perpare()
{
	for(;limit<=n+m;limit<<=1)
		++l;
	const int top_shift=l-1;
	for(int i=0;i<limit;i++)
		r[i]=(r[i>>1]>>1)|((i&1)<<top_shift);
}
int main()
{
	ios::sync_with_stdio(false);
	cin>>n>>m;
	for(int i=0;i<=n;i++)
		cin>>a[i].first;
	for(int i=0;i<=m;i++)
		cin>>b[i].first;
	perpare();
	work();
	return 0;
}

NTT

讲解

难得一见的讲解:

由于 \(FFT\) 会有精度问题,而且不能取模,所以 \(NTT\) 就诞生了。

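我们只需将 \(FFT\) 中的 \(n\) 次单位根 \(\omega_n\) 换成 \(g^{(p-1)/n} \bmod p\) 就好了,其中 \(p\) 是模数,\(g\) 是 \(p\) 的原根。这要求 \(n \mid p-1\),例如 \(p = 998244353 = 119 \times 2^{23} + 1\),\(g = 3\)。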

如果题目本身不需要取模,只需要选一个比结果中任何系数都大、并且满足上述条件的模数就好了,这样取模就相当于没有取模。

当然,逆变换最后除以 \(len\) 的时候,要改为乘 \(len\) 在模 \(p\) 意义下的逆元(由费马小定理,即 \(len^{p-2} \bmod p\))。

练习

板题(UOJ)

板题(洛谷)

其实你可以用 \(NTT\) 过所有 \(FFT\) 的题。 好像并不是:当系数不是整数、结果可能超过模数,或题目要求任意模数时,\(NTT\) 就不适用了(任意模数可以用下面的 \(MTT\))。

代码

板题代码

$Mine\ \text{(递归)}$
//12252024832524
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define TT template<typename T>
using namespace std; 

typedef long long LL;
const int MAXN = 1 << 21 | 5;   // array capacity; must exceed the padded transform length
const int MOD = 998244353;      // NTT-friendly prime: 119 * 2^23 + 1
const int PHI = 998244352;      // MOD - 1
const int GINV = 332748118;     // inverse of G modulo MOD
const int G = 3;                // primitive root modulo MOD
int lena,lenb;                  // degrees of the two input polynomials
int a[MAXN],b[MAXN];            // coefficients (later transforms) of the two polynomials

/// Reads the next integer from stdin.
/// Characters before the first digit are skipped; any '-' among them makes the
/// result negative. At end of input it stops and returns what was read
/// (0 if no digits) instead of spinning forever: c is an int, not a char,
/// so EOF stays distinguishable from every real character.
LL Read()
{
	LL x = 0,f = 1;
	int c = getchar();
	while(c != EOF && (c < '0' || c > '9')){if(c == '-') f = -1;c = getchar();}
	while(c >= '0' && c <= '9'){x = x * 10 + (c - '0');c = getchar();}
	return x * f;
}
/// Writes the decimal digits of a non-negative integer x (no sign, no separator).
template<typename T> void Put1(T x)
{
	if(x > 9) Put1(x / 10);
	putchar('0' + static_cast<int>(x % 10));
}
/// Writes integer x, then character c unless c is negative (default: nothing).
/// c is an int rather than a char: with a char parameter the default -1 turns
/// into 255 where plain char is unsigned, and a stray byte would be printed.
template<typename T> void Put(T x,int c = -1)
{
	if(x < 0) putchar('-'),x = -x;
	Put1(x);
	if(c >= 0) putchar(c);
}
template<typename T> T Max(T x,T y){return x > y ? x : y;}
template<typename T> T Min(T x,T y){return x < y ? x : y;}
template<typename T> T Abs(T x){return x < 0 ? -x : x;}

/// Modular exponentiation: returns x^y modulo MOD by square-and-multiply.
/// Loops until y reaches 0, so y is expected to be non-negative.
int qpow(int x,int y)
{
	long long acc = 1,sq = x;
	for(;y;y >>= 1)
	{
		if(y & 1) acc = acc * sq % MOD;
		sq = sq * sq % MOD;
	}
	return (int)acc;
}
void NTT(int len,int *a,int f)
{
	if(len == 1) return;
	int a1[len >> 1],a2[len >> 1];
	for(int i = 0;i < len;i += 2) a1[i >> 1] = a[i],a2[i >> 1] = a[i+1];
	NTT(len>>1,a1,f);
	NTT(len>>1,a2,f);
	int w = qpow(f == 1 ? G : GINV,PHI/len),k = 1;
	len >>= 1;
	for(int i = 0;i < len;++ i,k = 1ll * k * w % MOD)
	{
		a[i] = (a1[i] + 1ll * k * a2[i]) % MOD;
		a[i+len] = (a1[i] - 1ll * k * a2[i]) % MOD;
	}
}

int main()
{
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	// Input: both degrees, then the coefficients of each polynomial.
	lena = Read(); lenb = Read();
	for(int i = 0;i <= lena;++ i) a[i] = Read();
	for(int i = 0;i <= lenb;++ i) b[i] = Read();
	// Transform length: smallest power of two strictly greater than the product degree.
	int len = 1;
	while(len <= lena + lenb) len <<= 1;
	NTT(len,a,1);
	NTT(len,b,1);
	// Point-wise product over exactly the len transformed entries (0..len-1).
	for(int i = 0;i < len;++ i) a[i] = 1ll * a[i] * b[i] % MOD;
	NTT(len,a,-1);
	// The recursive inverse is unscaled: multiply by len^{-1} (Fermat: len^(MOD-2))
	// and normalise any negative residue into [0, MOD).
	const int invlen = qpow(len,MOD-2);
	for(int i = 0;i <= lena+lenb;++ i) Put((1ll * a[i] * invlen % MOD + MOD) % MOD,' ');
	return 0;
}
$Mine\ \text{(迭代)}$
//12252024832524
#include <cmath>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define TT template<typename T>
using namespace std; 

typedef long long LL;
const int MAXN = 1 << 21 | 5;   // array capacity; must exceed the padded transform length
const int MOD = 998244353;      // NTT-friendly prime: 119 * 2^23 + 1
const int PHI = 998244352;      // MOD - 1
const int GINV = 332748118;     // inverse of G modulo MOD
const int G = 3;                // primitive root modulo MOD
int lena,lenb,len = 1,l = -1;   // input degrees; transform length (grown in main); log2(len) - 1
int a[MAXN],b[MAXN],rev[MAXN];  // the two polynomials and the bit-reversal permutation

/// Reads the next integer from stdin.
/// Characters before the first digit are skipped; any '-' among them makes the
/// result negative. At end of input it stops and returns what was read
/// (0 if no digits) instead of spinning forever: c is an int, not a char,
/// so EOF stays distinguishable from every real character.
LL Read()
{
	LL x = 0,f = 1;
	int c = getchar();
	while(c != EOF && (c < '0' || c > '9')){if(c == '-') f = -1;c = getchar();}
	while(c >= '0' && c <= '9'){x = x * 10 + (c - '0');c = getchar();}
	return x * f;
}
/// Writes the decimal digits of a non-negative integer x (no sign, no separator).
template<typename T> void Put1(T x)
{
	if(x > 9) Put1(x / 10);
	putchar('0' + static_cast<int>(x % 10));
}
/// Writes integer x, then character c unless c is negative (default: nothing).
/// c is an int rather than a char: with a char parameter the default -1 turns
/// into 255 where plain char is unsigned, and a stray byte would be printed.
template<typename T> void Put(T x,int c = -1)
{
	if(x < 0) putchar('-'),x = -x;
	Put1(x);
	if(c >= 0) putchar(c);
}
template<typename T> T Max(T x,T y){return x > y ? x : y;}
template<typename T> T Min(T x,T y){return x < y ? x : y;}
template<typename T> T Abs(T x){return x < 0 ? -x : x;}

/// Modular exponentiation: returns x^y modulo MOD by square-and-multiply.
/// Loops until y reaches 0, so y is expected to be non-negative.
int qpow(int x,int y)
{
	long long acc = 1,sq = x;
	for(;y;y >>= 1)
	{
		if(y & 1) acc = acc * sq % MOD;
		sq = sq * sq % MOD;
	}
	return (int)acc;
}
/// Iterative radix-2 NTT modulo MOD over the global length `len`, in place on a[0..len).
/// Expects rev[] to hold the bit-reversal permutation for `len`; entries are
/// assumed to be residues in [0, MOD).
/// opt = 1: forward transform (root G); opt = -1: inverse transform (root GINV)
/// including the multiplication by len^{-1}.
void NTT(int *a,int opt)
{
	for(int i = 0;i < len;++ i) if(i < rev[i]) swap(a[i],a[rev[i]]);
	for(int i = 1;i < len;i <<= 1)
	{
		// (2i)-th root of unity modulo MOD (its inverse for the inverse transform).
		const int w = qpow(opt == 1 ? G : GINV,PHI / (i << 1));
		for(int j = 0,p = i << 1;j < len;j += p)
		{
			int mi = 1;
			for(int k = 0;k < i;++ k,mi = 1ll * mi * w % MOD)
			{
				const int X = a[j+k],Y = 1ll * mi * a[i+j+k] % MOD;
				a[j+k] = (X + Y) % MOD;
				a[i+j+k] = (X - Y + MOD) % MOD;
			}
		}
	}
	if(opt == -1)
	{
		// Only the inverse needs len^{-1}; skip the modular exponentiation otherwise.
		const int invlen = qpow(len,MOD-2);
		for(int i = 0;i < len;++ i) a[i] = 1ll * a[i] * invlen % MOD;
	}
}

int main()
{
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	// Input: both degrees, then each polynomial's coefficients (constant term first).
	lena = Read();
	lenb = Read();
	for(int i = 0;i <= lena;++ i) a[i] = Read();
	for(int i = 0;i <= lenb;++ i) b[i] = Read();
	const int deg = lena + lenb;
	// Double len until it exceeds the product degree; l tracks log2(len) - 1.
	for(;len <= deg;len <<= 1) ++l;
	// rev[i]: i with its bits reversed, derived from rev[i / 2].
	for(int i = 0;i < len;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l);
	NTT(a,1);
	NTT(b,1);
	for(int i = 0;i < len;++ i) a[i] = 1ll * a[i] * b[i] % MOD;
	NTT(a,-1);   // the inverse call already multiplies by len^{-1}
	for(int i = 0;i <= deg;++ i) Put(a[i],' ');
	return 0;
}

MTT

讲解

lvzelong2014巨佬

记得用 long double(\(\pi\) 本身也要按 long double 精度计算,例如 acos(-1.0L);acos(-1) 只有 double 精度)。

练习

板题(洛谷)

代码

板题(洛谷)代码
//12252024832524
#include <bits/stdc++.h>
#define TT template<typename T>
using namespace std;

typedef long long LL;
// MAXN: array capacity (the transform length must stay below it).
// M = 2^15 - 1: mask extracting the low 15 bits of a coefficient.
const int MAXN = 1 << 18 | 5,M = 32767;
// Use the long double overload explicitly: acos(-1) with an int argument is
// treated as acos(double), which silently limits PI (and every twiddle factor
// derived from it) to double precision.
const long double PI = std::acos(-1.0L);
int lena,lenb,mod;                 // input degrees and the modulus
int A[MAXN],B[MAXN],ans[MAXN];     // the two input polynomials and their product

/// Reads the next integer from stdin.
/// Characters before the first digit are skipped; any '-' among them makes the
/// result negative. At end of input it stops and returns what was read
/// (0 if no digits) instead of spinning forever: c is an int, not a char,
/// so EOF stays distinguishable from every real character.
LL Read()
{
	LL x = 0,f = 1;
	int c = getchar();
	while(c != EOF && (c < '0' || c > '9')){if(c == '-') f = -1;c = getchar();}
	while(c >= '0' && c <= '9'){x = x * 10 + (c - '0');c = getchar();}
	return x * f;
}
/// Writes the decimal digits of a non-negative integer x (no sign, no separator).
template<typename T> void Put1(T x)
{
	if(x > 9) Put1(x / 10);
	putchar('0' + static_cast<int>(x % 10));
}
/// Writes integer x, then character c unless c is negative (default: nothing).
/// c is an int rather than a char: with a char parameter the default -1 turns
/// into 255 where plain char is unsigned, and a stray byte would be printed.
template<typename T> void Put(T x,int c = -1)
{
	if(x < 0) putchar('-'),x = -x;
	Put1(x);
	if(c >= 0) putchar(c);
}
template<typename T> T Max(T x,T y){return x > y ? x : y;}
template<typename T> T Min(T x,T y){return x < y ? x : y;}
template<typename T> T Abs(T x){return x < 0 ? -x : x;}

/// Complex number in long double precision (an aggregate: build it with cp{re, im}).
struct cp
{
	long double x,y;   // real part, imaginary part
	cp operator + (const cp &C)const
	{
		return cp{x + C.x,y + C.y};
	}
	cp operator - (const cp &C)const
	{
		return cp{x - C.x,y - C.y};
	}
	cp operator * (const cp &C)const
	{
		const long double re = x * C.x - y * C.y;
		const long double im = x * C.y + y * C.x;
		return cp{re,im};
	}
};
// Transform buffers (a, b, da..dd) and scratch values (ab, ak, bb, bk), all
// zero-initialised globals. I (0 + 0i) is not used by any live code.
cp a[MAXN],b[MAXN],I,ab,ak,bb,bk,da[MAXN],db[MAXN],dc[MAXN],dd[MAXN];
/// Complex conjugate of C.
cp cj(cp C)
{
	C.y = -C.y;
	return C;
}

int rev[MAXN],len;   // bit-reversal permutation and the current transform length
/// Sets len to the smallest power of two greater than L and fills rev[1..len)
/// with the bit-reversal permutation of [0, len); rev[0] is never written and stays 0.
void pre(int L)
{
	len = 1;
	int shift = -1;   // ends at log2(len) - 1
	while(len <= L)
	{
		len <<= 1;
		++shift;
	}
	for(int i = 1;i < len;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << shift);
}
/// Iterative radix-2 FFT over the global length `len`, in place on a[0..len),
/// in long double precision. Expects rev[] as filled by pre().
/// f = 1: forward transform; f = -1: inverse transform, including division of
/// both the real and imaginary parts by len.
/// (Used by mul(), which splits each coefficient as hi * 2^15 + lo.)
///
/// Twiddle factors are evaluated directly with cos/sin once per level
/// (len - 1 evaluations per call). Building w^k by repeated multiplication
/// accumulates rounding error that grows with k. MTT recombines convolution
/// terms that can reach 2^30 times the number of terms, so that drift can
/// flip the final rounding.
void FFT(cp *a,int f)
{
	static cp root[MAXN];
	for(int i = 0;i < len;++ i) if(i < rev[i]) swap(a[i],a[rev[i]]);
	for(int i = 1;i < len;i <<= 1)
	{
		// root[k] = exp(f * pi * i_unit * k / i) for this level.
		for(int k = 0;k < i;++ k) root[k] = cp{cos(PI*k/i),f*sin(PI*k/i)};
		for(int j = 0,p = i << 1;j < len;j += p)
			for(int k = 0;k < i;++ k)
			{
				const cp X = a[j+k],Y = root[k] * a[i+j+k];
				a[j+k] = X+Y;
				a[i+j+k] = X-Y;
			}
	}
	if(f == -1) for(int i = 0;i < len;++ i) a[i].x /= len,a[i].y /= len;
}
void mul(int *ma,int *mb,int *mc,int L,const int MOD)
{
	pre(L);
//	for(int i = 0;i < len;++ i) a[i] = b[i] = I;
	for(int i = 0;i <= lena;++ i) a[i] = cp{ma[i] & M,ma[i] >> 15};
	for(int i = 0;i <= lenb;++ i) b[i] = cp{mb[i] & M,mb[i] >> 15};
	FFT(a,1); FFT(b,1);
	for(int i = 0,j;i < len;++ i)
	{
		j = (len-i) & (len-1);
		ab = (a[i] + cj(a[j])) * cp{0.5,0};
		ak = (a[i] - cj(a[j])) * cp{0,-0.5};
		bb = (b[i] + cj(b[j])) * cp{0.5,0};
		bk = (b[i] - cj(b[j])) * cp{0,-0.5};
		da[i] = ab * bb; db[i] = ab * bk; dc[i] = ak * bb; dd[i] = ak * bk;
	}
	for(int i = 0;i < len;++ i) a[i] = da[i] + db[i] * cp{0,1};
	for(int i = 0;i < len;++ i) b[i] = dc[i] + dd[i] * cp{0,1};
	FFT(a,-1); FFT(b,-1);
	for(int i = 0;i <= lena+lenb;++ i)
	{
		LL v1 = ((LL)(a[i].x+0.5)) % MOD,v2 = ((LL)(a[i].y+0.5)) % MOD,v3 = ((LL)(b[i].x+0.5)) % MOD,v4 = ((LL)(b[i].y+0.5)) % MOD;
		mc[i] = (v1 + ((v2 + v3) << 15) + (v4 << 30)) % MOD;
	}
}

int main()
{
//	freopen(".in","r",stdin);
//	freopen(".out","w",stdout);
	// Input: both degrees, the modulus, then the coefficients of each polynomial.
	lena = Read();
	lenb = Read();
	mod = Read();
	for(int i = 0;i <= lena;++ i) A[i] = Read();
	for(int i = 0;i <= lenb;++ i) B[i] = Read();
	const int top = lena + lenb;
	mul(A,B,ans,top,mod);
	// Space-separated coefficients; the last one is followed by a newline.
	for(int i = 0;i < top;++ i) Put(ans[i],' ');
	Put(ans[top],'\n');
	return 0;
}
posted @ 2021-02-02 15:32  皮皮刘  阅读(428)  评论(0)  编辑  收藏  举报