快速傅里叶变换(FFT)及快速数论变换(NTT)详解

前置知识

复数
单位根(下面讲)
多项式(下面讲)

单位根

前置知识:复数

在复平面上,所有与原点距离为 1 的点组成单位圆。
由于两个复数相乘的法则(辐角相加,模长相乘),在单位圆上的两个复数相乘还是在单位圆上(模长都是 1),且辐角为这两个复数相加。
所以对于 xn=1 这个方程来说,它的解在单位圆上,且辐角为 360n 的倍数,显然这样的复数有 n 个,我们把这样的 n 个复数叫做 n 次单位根,并将其中辐角最小的(1 除外)记作 ωn
所有 ωni 都是 n 次单位根。

单位根

上图就是 10 次单位根 在复平面上的表示。

单位根有三个引理:(证明显然

  1. 消去引理:ωknkm=ωnm
  2. 折半引理:(ωnk+n2)2=ωn2k=ωn2k
  3. 求和引理:i=0n1(ωnk)i=0

在 FFT 中,我们只需要用前两个。

多项式

这里我们讨论的多项式是只有一个变量 x 的多项式,即 F(x)=i=0naixi

多项式有两种表示方法:(设多项式有 n+1 项)

  1. 系数表示,即将多项式表示成:[a0,a2,,an],其中 aixi 项的系数
  2. 点值表示,即将多项式表示成:[(v0,F(v0)),(v1,F(v1)),,(vn,F(vn))],即将 n+1 个不同的值 vi 带入多项式得到 F(vi),然后用这 n+1 个二元组唯一确定一个多项式。(不考虑那种无用点值之类的东西)

容易发现,对于两个多项式 F(x)G(x),设 H(x)=F(x)×G(x),如果有 FG 的点值表示,那么 H 的点值表示也可以很快算出来:H(vi)=F(vi)×G(vi)

所以,FFT 要做的事就是将系数表示转换成点值表示(及其逆过程,这个后面再说)。

FFT

几个概念

  1. 离散傅里叶变换(Discrete Fourier Transform,DFT),即 将系数表示转换成点值表示,其中 vi=ωni
  2. 快速傅里叶变换(即 快速(离散)傅里叶变换,Fast (Discrete) Fourier Transform,FFT),即快速做 DFT。

思路

从这里开始,我们将 n 定为多项式的项数,而不是多项式的次数。

FFT 可以(且必须)直接计算 2k 个点值,所以我们将 n 取到比它大(或等于它)的最小的 2 的幂次。(n2k

设多项式为 F(x)=i=0n1aixi,并且令点值表示法为 [(v0,y0),(v1,y1),,(vn1,yn1)]
容易发现 yi=j=0naj(vi)j=j=0najvij

我们首先将 F(x) 的系数 A=[a0,a1,,an1] 按奇偶分为 A[0]A[1],即:
A[0]=[a0,a2,,an2], A[1]=[a1,a3,,an1]

那么我们会发现 F(x)=F[0](x2)+xF[1](x2)读者自证不难
所以我们就有了一个分治想法:用 F[0]F[1] 的点值表示来算出 F 的点值表示。

但是这里会有一个问题:F[0]F[1] 的点值表示都只有 n2 个,所以合并后也只有 n2 个。
所以我们还需要用 F[0]F[1]n2 个点值来求出 F 的后 n2 个点值。

这怎么办呢?随便取 (vi,yi) 的值肯定是无法做到的,我们要取一些有特殊性质的值。现在就要用到我们之前说的单位根了。

我们取 vi=ωni,这有什么用呢?我们代两个值进去看看:(记 k=k+n2

F(vk)=F(ωnk)=F[0]((ωnk)2)+ωnkF[1]((ωnk)2)=F[0](ωn2k)+ωnkF[1](ωn2k)F(vk)=F(ωnk)=F[0]((ωnk)2)+ωnkF[1]((ωnk)2)=F[0](ωn2k)ωnkF[1](ωn2k)

所以我们只需要用 F[0]F[1] 点值表示就可以直接求出 F 的点值表示啦。

然后怎么办呢?对于 n=1 的情况,F(x)=A0x,又由于 n=1,所以 x0=1,那么 F(x) 的点值表示就是 A0 啦。

非递归写法

由于 FFT 非常的常用,常数也很大,所以我们需要一定的卡常技巧。为了学习隔壁的 zkw 线段树我们发明出了 FFT 的非递归写法。

首先我们发现,只要求出所有 n=1 时的点值表示,其它的就可以轻松地用循环求出(把递归树画出来,用循环模拟过程。因为 n 是 2 的幂次,所以很好模拟)。

所以现在我们只需要求出 n=1 时的点值表示。考虑一个递归过程:(每次将上面一行 A 分成两个 A[0]A[1],将两个数组用 | 隔开)

F(8): a[0]   a[1]   a[2]   a[3]   a[4]   a[5]   a[6]   a[7]
F(4): a[0]   a[2]   a[4]   a[6] | a[1]   a[3]   a[5]   a[7]
F(2): a[0]   a[4] | a[2]   a[6] | a[1]   a[5] | a[3]   a[7]
F(1): a[0] | a[4] | a[2] | a[6] | a[1] | a[5] | a[3] | a[7]

我们再来看最后一行的下标,这次我们用二进制表示出来:(上面是十进制,下面是二进制)

 0   4   2   6   1   5   3   7
000 100 010 110 001 101 011 111

这个二进制有什么规律呢?我们将每个二进制倒过来并转成十进制看看:(上面是倒过来的二进制,下面是十进制)

000 001 010 011 100 101 110 111
 0   1   2   3   4   5   6   7

现在规律已经很明显了。我们要求的数在二进制意义下倒过来是顺序排列的,我们把这种序列叫做 逆二进制序

那么现在我们就解决了非递归版的 FFT。

代码在最后面。

逆 FFT(IFFT)

我们再回到之前的公式:yi=j=0najωnij
我们把它写成矩阵:y=Vn×a,其中 (Vn)i,j=ωnij
所以 a=Vn1y,其中 Vn1V 的逆矩阵。
通过某些我不会的方法算出来:(Vn1)i,j=ωnijn
然后就用 FFT 的方法,改一下公式就好了。

代码在最后面。

NTT

FFT 的优势很多,但是缺陷也很明显:需要用 double,所以会有精度问题,而且不能取模。

那么有没有其它的东西可以支持取模且没有精度问题呢?当然是有的,这就是快速数论变换(NTT,Number Theoretic Transform)。

这里的 NTT 解决的是模数为 998244353 的做法。如果是其它的模数,需要使用下面的 任意模数 NTT

首先观察一下 998244353 的性质:(数论不懂的可以看看我 这篇 博客)

  • 它是一个质数。
  • φ(998244353)=998244352=223×7×17
  • 它有 原根,其中一个是 3

什么是原根?

对于两个数 am,如果 gcd(a,m)=1,那么我们有 aφ(m)1(modm)(欧拉定理)。>

如果对于任意 0n<φ(m),满足 an1(modm),那么我们称 am 的原根。

我们回到 FFT 上来。当时我们为什么要令 xi=ωni 呢?因为它满足一下几条性质:

  • {ωni}(0i<n) 互不相同
  • ωkmkn=ωmn
  • (ωnk+n2)2=ωn2k=ωn2k

那么我们用原根是否也能做到这几条性质呢?答案是可以。

记原根为 g(模数为 998244353g=3),那么我们令 gn=gp1n,并令 xi=gni

那么我们可以证明这几条性质:(第一条上面 OI-Wiki 的链接里有,我 不会证 就不证了)

  • gkmkn=gkm(p1)kn=gm(p1)n=gmn
  • (gnk+n2)2=gn2k+n=gn2k=gn2k

那么我们就完美解决了所有性质,把 FFT 的板子套上去,换成 NTT 的公式就好了。

这时候有小可爱可能就会问了:你这个 gp1n 中指数 p1n 有没有可能不是整数啊?

这就要用到上面的质因数分解了:p1=998244352=223×7×17,又因为前面我们把 n 调整为 2 的幂次了,所以当 n223=83886088.3×106 时都是够用的啦。

另外,这种方法不止适用于 998244353,还有 4697620491004535809。(它们的原根都是 3 哦)

另外一些数也可以用这种方法,参见 原根表

代码在最后面。

任意模数 NTT(MTT)

如果模数不是上面的,或者在输入中给定,又怎么办呢?

这时候就需要 任意模数 NTT(any Module NTT)了。

对于任意模数,我们无法得到上面的性质了。怎么办呢?我们可以自己取模数!

具体地,我们取一些模数 p1,p2,pk 使得答案多项式的系数在 取模之前 不会超过 pi。一般来说取 3 个质数(9982443534697620491004535809)就够了。

然后我们先算出答案对每个 pi 取模的结果,利用 CRT 就可以求得答案对 pi 取模的结果。又因为答案小于 pi,所以这个结果就是答案。(可以在 CRT 的过程中就对原题模数取模,这样就不会爆 long long)然后将这个答案对题目中的模数取一次模就好了。

例题:洛谷 P4245

代码在最后面。

代码

都是非递归版的。

FFT 和 IFFT 代码

#include <cmath>
#include <algorithm>
const int N = /* ... */ + 5; // 2 * (n + m)
const double PI = acos(-1);
struct Complex {
	double x, y;
	Complex(double x_ = 0, double y_ = 0) : x(x_), y(y_) {}
};
Complex operator+(const Complex &a, const Complex &b) { return Complex(a.x + b.x, a.y + b.y); }
Complex operator-(const Complex &a, const Complex &b) { return Complex(a.x - b.x, a.y - b.y); }
Complex operator*(const Complex &a, const Complex &b) { return Complex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }
struct FFT {
	int rev[N];
	int limit;
	int init(int mx) {
		int w = 0;
		limit = 1;
		while(limit <= mx) limit <<= 1, w++; // w = log2(limit)
		for(int i = 0; i < limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (w - 1));
		return limit;
	}
	void trans(Complex *a, int type) { // type = 1 / -1
		for(int i = 0; i < limit; i++) if(i < rev[i]) std::swap(a[i], a[rev[i]]);
		for(int i = 1; i < limit; i <<= 1) {
			Complex wn = Complex(cos(PI / i), type * sin(PI / i));
			for(int j = 0; j < limit; j += (i << 1)) {
				Complex w(1, 0);
				for(int k = 0; k < i; k++, w = w * wn) {
					Complex x = a[j + k], y = w * a[j + i + k];
					a[j + k] = x + y;
					a[j + i + k] = x - y;
				}
			}
		}
		if(type == -1) for(int i = 0; i < limit; i++) a[i].x /= limit;
	}
	int trans(Complex *a, int n, int type) { int ret = init(n); trans(a, type); return ret; }
};
// 需要先 init 再调用第一个 trans,或者直接调用第二个 trans

NTT 代码

#include <algorithm>
typedef long long LL;
const int N = /* ... */ + 5; // 2 * (n + m)
template<LL mod, LL g> struct NTT {
	int rev[N];
	int limit;
	int init(int mx) {
		int w = 0;
		limit = 1;
		while(limit <= mx) limit <<= 1, w++; // w = log2(limit)
		for(int i = 0; i < limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (w - 1));
		return limit;
	}
	inline LL qpow(LL x, LL y) { LL ret = 1; while(true) { if(y & 1) ret = ret * x % mod; if(!(y >>= 1)) return ret; x = x * x % mod; } }
	inline LL inv(LL x) { return qpow(x, mod - 2); }
	void trans(LL *a, int type) { // type = 1 / -1
		LL invg = inv(g);
		for(int i = 0; i < limit; i++) if(i < rev[i]) std::swap(a[i], a[rev[i]]);
		for(int i = 1; i < limit; i <<= 1) {
			LL wn = qpow(type == -1 ? invg : g, (mod - 1) / (i << 1));
			for(int j = 0; j < limit; j += (i << 1)) {
				LL w = 1;
				for(int k = 0; k < i; k++, w = w * wn % mod) {
					LL x = a[j + k], y = w * a[j + i + k] % mod;
					a[j + k] = (x + y) % mod;
					a[j + i + k] = (x - y + mod) % mod;
				}
			}
		}
		if(type == -1) for(int i = 0; i < limit; i++) (a[i] *= inv(limit)) %= mod;
	}
	int trans(LL *a, int n, int type) { int ret = init(n); trans(a, type); return ret; }
};
// 需要先 init 再调用第一个 trans,或者直接调用第二个 trans

任意模数 NTT(MTT)例题代码

#include <cstdio>
#include <algorithm>

typedef long long LL;
const int N = 4e5 + 5;
const LL P1 = 469762049;
const LL P2 = 998244353;
const LL P3 = 1004535809;

template<LL mod, LL g> struct NTT {
	int rev[N];
	int limit;
	int init(int mx) {
		int w = 0;
		limit = 1;
		while(limit <= mx) limit <<= 1, w++; // w = log2(limit)
		for(int i = 0; i < limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (w - 1));
		return limit;
	}
	inline LL qpow(LL x, LL y) { LL ret = 1; while(true) { if(y & 1) ret = ret * x % mod; if(!(y >>= 1)) return ret; x = x * x % mod; } }
	inline LL inv(LL x) { return qpow(x, mod - 2); }
	void trans(LL *a, int type) { // type = 1 / -1
		LL invg = inv(g);
		for(int i = 0; i < limit; i++) if(i < rev[i]) std::swap(a[i], a[rev[i]]);
		for(int i = 1; i < limit; i <<= 1) {
			LL wn = qpow(type == -1 ? invg : g, (mod - 1) / (i << 1));
			for(int j = 0; j < limit; j += (i << 1)) {
				LL w = 1;
				for(int k = 0; k < i; k++, w = w * wn % mod) {
					LL x = a[j + k], y = w * a[j + i + k] % mod;
					a[j + k] = (x + y) % mod;
					a[j + i + k] = (x - y + mod) % mod;
				}
			}
		}
		if(type == -1) for(int i = 0; i < limit; i++) (a[i] *= inv(limit)) %= mod;
	}
	int trans(LL *a, int n, int type) { int ret = init(n); trans(a, type); return ret; }
};

LL a[N], b[N], ca[N], cb[N], ans1[N], ans2[N], ans3[N];
int n, m;
LL P;

LL qpow(LL x, LL y, LL mod) { LL ret = 1; while(true) { if(y & 1) ret = ret * x % mod; if(!(y >>= 1)) return ret; x = x * x % mod; } }
LL inv(LL x, LL mod) { return qpow(x, mod - 2, mod); }

NTT<P1, 3> ntt1;
NTT<P2, 3> ntt2;
NTT<P3, 3> ntt3;

int main() {
	scanf("%d%d%lld", &n, &m, &P);
	for(int i = 0; i <= n; i++) scanf("%lld", &a[i]);
	for(int i = 0; i <= m; i++) scanf("%lld", &b[i]);
	int limit = ntt1.init(n + m);
	ntt2.init(n + m), ntt3.init(n + m);
	for(int i = 0; i <= limit; i++) ca[i] = a[i], cb[i] = b[i];
	ntt1.trans(ca, 1), ntt1.trans(cb, 1);
	for(int i = 0; i <= limit; i++) ans1[i] = ca[i] * cb[i] % P1;
	ntt1.trans(ans1, -1);
	for(int i = 0; i <= limit; i++) ca[i] = a[i], cb[i] = b[i];
	ntt2.trans(ca, 1), ntt2.trans(cb, 1);
	for(int i = 0; i <= limit; i++) ans2[i] = ca[i] * cb[i] % P2;
	ntt2.trans(ans2, -1);
	for(int i = 0; i <= limit; i++) ca[i] = a[i], cb[i] = b[i];
	ntt3.trans(ca, 1), ntt3.trans(cb, 1);
	for(int i = 0; i <= limit; i++) ans3[i] = ca[i] * cb[i] % P3;
	ntt3.trans(ans3, -1);
	for(int i = 0; i <= n + m; i++) {
		// 三个质数可以手推 CRT
		// 看着这个推也可以 https://www.cnblogs.com/Memory-of-winter/p/10223844.html
		LL out = 0;
		LL tmp = ans1[i] + (ans2[i] - ans1[i] + P2) % P2 * inv(P1, P2) % P2 * P1;
		out = (tmp + (ans3[i] - tmp % P3 + P3) % P3 * inv(P1 * P2 % P3, P3) % P3 * P1 % P * P2 % P) % P;
		printf("%lld ", out);
	}
	puts("");
	return 0;
}

{% endnote %}

完结撒花~

https://www.cnblogs.com/zwfymqz/p/8244902.html
https://www.bilibili.com/video/BV1Y7411W73U
https://www.luogu.com.cn/problem/solution/P4245
https://www.cnblogs.com/Memory-of-winter/p/10223844.html
https://blog.csdn.net/zhouyuheng2003/article/details/85561887
https://www.cnblogs.com/Sakits/p/8416918.html
https://oi-wiki.org/math/poly/ntt/
https://www.cnblogs.com/zarth/p/7288456.html
http://www.longluo.me/blog/2022/05/01/Number-Theoretic-Transform/

posted @   XxEray  阅读(804)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
· SQL Server 2025 AI相关能力初探
点击右上角即可分享
微信分享提示