快速傅里叶变换(FFT)及快速数论变换(NTT)详解
前置知识
复数
单位根(下面讲)
多项式(下面讲)
单位根
前置知识:复数
在复平面上,所有与原点距离为
由于两个复数相乘的法则(辐角相加,模长相乘),在单位圆上的两个复数相乘还是在单位圆上(模长都是
所以对于
所有
上图就是
单位根有三个引理:(证明显然)
- 消去引理:
- 折半引理:
- 求和引理:
在 FFT 中,我们只需要用前两个。
多项式
这里我们讨论的多项式是只有一个变量
多项式有两种表示方法:(设多项式有
- 系数表示,即将多项式表示成:
,其中 为 项的系数 - 点值表示,即将多项式表示成:
,即将 个不同的值 带入多项式得到 ,然后用这 个二元组唯一确定一个多项式。(不考虑那种无用点值之类的东西)
容易发现,对于两个多项式
所以,FFT 要做的事就是将系数表示转换成点值表示(及其逆过程,这个后面再说)。
FFT
几个概念
- 离散傅里叶变换(Discrete Fourier Transform,DFT),即 将系数表示转换成点值表示,其中
。 - 快速傅里叶变换(即 快速(离散)傅里叶变换,Fast (Discrete) Fourier Transform,FFT),即快速做 DFT。
思路
从这里开始,我们将
定为多项式的项数,而不是多项式的次数。
FFT 可以(且必须)直接计算
设多项式为
容易发现
我们首先将
那么我们会发现 读者自证不难。
所以我们就有了一个分治想法:用
但是这里会有一个问题:
所以我们还需要用
这怎么办呢?随便取
我们取
所以我们只需要用
然后怎么办呢?对于
非递归写法
由于 FFT 非常的常用,常数也很大,所以我们需要一定的卡常技巧。为了学习隔壁的 zkw 线段树我们发明出了 FFT 的非递归写法。
首先我们发现,只要求出所有
所以现在我们只需要求出 |
隔开)
F(8): a[0] a[1] a[2] a[3] a[4] a[5] a[6] a[7]
F(4): a[0] a[2] a[4] a[6] | a[1] a[3] a[5] a[7]
F(2): a[0] a[4] | a[2] a[6] | a[1] a[5] | a[3] a[7]
F(1): a[0] | a[4] | a[2] | a[6] | a[1] | a[5] | a[3] | a[7]
我们再来看最后一行的下标,这次我们用二进制表示出来:(上面是十进制,下面是二进制)
0 4 2 6 1 5 3 7
000 100 010 110 001 101 011 111
这个二进制有什么规律呢?我们将每个二进制倒过来并转成十进制看看:(上面是倒过来的二进制,下面是十进制)
000 001 010 011 100 101 110 111
0 1 2 3 4 5 6 7
现在规律已经很明显了。我们要求的数在二进制意义下倒过来是顺序排列的,我们把这种序列叫做 逆二进制序。
那么现在我们就解决了非递归版的 FFT。
代码在最后面。
逆 FFT(IFFT)
我们再回到之前的公式:
我们把它写成矩阵:
所以
通过某些我不会的方法算出来:
然后就用 FFT 的方法,改一下公式就好了。
代码在最后面。
NTT
FFT 的优势很多,但是缺陷也很明显:需要用 double
,所以会有精度问题,而且不能取模。
那么有没有其它的东西可以支持取模且没有精度问题呢?当然是有的,这就是快速数论变换(NTT,Number Theoretic Transform)。
这里的 NTT 解决的是模数为
首先观察一下
- 它是一个质数。
- 它有 原根,其中一个是
。
什么是原根?
对于两个数
和 ,如果 ,那么我们有 (欧拉定理)。> 如果对于任意
,满足 ,那么我们称 是 的原根。
我们回到 FFT 上来。当时我们为什么要令
互不相同
那么我们用原根是否也能做到这几条性质呢?答案是可以。
记原根为
那么我们可以证明这几条性质:(第一条上面 OI-Wiki 的链接里有,我 不会证 就不证了)
那么我们就完美解决了所有性质,把 FFT 的板子套上去,换成 NTT 的公式就好了。
这时候有小可爱可能就会问了:你这个
这就要用到上面的质因数分解了:
另外,这种方法不止适用于
另外一些数也可以用这种方法,参见 原根表。
代码在最后面。
任意模数 NTT(MTT)
如果模数不是上面的,或者在输入中给定,又怎么办呢?
这时候就需要 任意模数 NTT(any Module NTT)了。
对于任意模数,我们无法得到上面的性质了。怎么办呢?我们可以自己取模数!
具体地,我们取一些模数
然后我们先算出答案对每个 long long
)然后将这个答案对题目中的模数取一次模就好了。
例题:洛谷 P4245
代码在最后面。
代码
都是非递归版的。
FFT 和 IFFT 代码
#include <cmath>
#include <algorithm>
const int N = /* ... */ + 5; // 2 * (n + m)
const double PI = acos(-1);
struct Complex {
double x, y;
Complex(double x_ = 0, double y_ = 0) : x(x_), y(y_) {}
};
Complex operator+(const Complex &a, const Complex &b) { return Complex(a.x + b.x, a.y + b.y); }
Complex operator-(const Complex &a, const Complex &b) { return Complex(a.x - b.x, a.y - b.y); }
Complex operator*(const Complex &a, const Complex &b) { return Complex(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }
struct FFT {
int rev[N];
int limit;
int init(int mx) {
int w = 0;
limit = 1;
while(limit <= mx) limit <<= 1, w++; // w = log2(limit)
for(int i = 0; i < limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (w - 1));
return limit;
}
void trans(Complex *a, int type) { // type = 1 / -1
for(int i = 0; i < limit; i++) if(i < rev[i]) std::swap(a[i], a[rev[i]]);
for(int i = 1; i < limit; i <<= 1) {
Complex wn = Complex(cos(PI / i), type * sin(PI / i));
for(int j = 0; j < limit; j += (i << 1)) {
Complex w(1, 0);
for(int k = 0; k < i; k++, w = w * wn) {
Complex x = a[j + k], y = w * a[j + i + k];
a[j + k] = x + y;
a[j + i + k] = x - y;
}
}
}
if(type == -1) for(int i = 0; i < limit; i++) a[i].x /= limit;
}
int trans(Complex *a, int n, int type) { int ret = init(n); trans(a, type); return ret; }
};
// 需要先 init 再调用第一个 trans,或者直接调用第二个 trans
NTT 代码
#include <algorithm>
typedef long long LL;
const int N = /* ... */ + 5; // 2 * (n + m)
template<LL mod, LL g> struct NTT {
int rev[N];
int limit;
int init(int mx) {
int w = 0;
limit = 1;
while(limit <= mx) limit <<= 1, w++; // w = log2(limit)
for(int i = 0; i < limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (w - 1));
return limit;
}
inline LL qpow(LL x, LL y) { LL ret = 1; while(true) { if(y & 1) ret = ret * x % mod; if(!(y >>= 1)) return ret; x = x * x % mod; } }
inline LL inv(LL x) { return qpow(x, mod - 2); }
void trans(LL *a, int type) { // type = 1 / -1
LL invg = inv(g);
for(int i = 0; i < limit; i++) if(i < rev[i]) std::swap(a[i], a[rev[i]]);
for(int i = 1; i < limit; i <<= 1) {
LL wn = qpow(type == -1 ? invg : g, (mod - 1) / (i << 1));
for(int j = 0; j < limit; j += (i << 1)) {
LL w = 1;
for(int k = 0; k < i; k++, w = w * wn % mod) {
LL x = a[j + k], y = w * a[j + i + k] % mod;
a[j + k] = (x + y) % mod;
a[j + i + k] = (x - y + mod) % mod;
}
}
}
if(type == -1) for(int i = 0; i < limit; i++) (a[i] *= inv(limit)) %= mod;
}
int trans(LL *a, int n, int type) { int ret = init(n); trans(a, type); return ret; }
};
// 需要先 init 再调用第一个 trans,或者直接调用第二个 trans
任意模数 NTT(MTT)例题代码
#include <cstdio>
#include <algorithm>
typedef long long LL;
const int N = 4e5 + 5;
const LL P1 = 469762049;
const LL P2 = 998244353;
const LL P3 = 1004535809;
template<LL mod, LL g> struct NTT {
int rev[N];
int limit;
int init(int mx) {
int w = 0;
limit = 1;
while(limit <= mx) limit <<= 1, w++; // w = log2(limit)
for(int i = 0; i < limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (w - 1));
return limit;
}
inline LL qpow(LL x, LL y) { LL ret = 1; while(true) { if(y & 1) ret = ret * x % mod; if(!(y >>= 1)) return ret; x = x * x % mod; } }
inline LL inv(LL x) { return qpow(x, mod - 2); }
void trans(LL *a, int type) { // type = 1 / -1
LL invg = inv(g);
for(int i = 0; i < limit; i++) if(i < rev[i]) std::swap(a[i], a[rev[i]]);
for(int i = 1; i < limit; i <<= 1) {
LL wn = qpow(type == -1 ? invg : g, (mod - 1) / (i << 1));
for(int j = 0; j < limit; j += (i << 1)) {
LL w = 1;
for(int k = 0; k < i; k++, w = w * wn % mod) {
LL x = a[j + k], y = w * a[j + i + k] % mod;
a[j + k] = (x + y) % mod;
a[j + i + k] = (x - y + mod) % mod;
}
}
}
if(type == -1) for(int i = 0; i < limit; i++) (a[i] *= inv(limit)) %= mod;
}
int trans(LL *a, int n, int type) { int ret = init(n); trans(a, type); return ret; }
};
LL a[N], b[N], ca[N], cb[N], ans1[N], ans2[N], ans3[N];
int n, m;
LL P;
LL qpow(LL x, LL y, LL mod) { LL ret = 1; while(true) { if(y & 1) ret = ret * x % mod; if(!(y >>= 1)) return ret; x = x * x % mod; } }
LL inv(LL x, LL mod) { return qpow(x, mod - 2, mod); }
NTT<P1, 3> ntt1;
NTT<P2, 3> ntt2;
NTT<P3, 3> ntt3;
int main() {
scanf("%d%d%lld", &n, &m, &P);
for(int i = 0; i <= n; i++) scanf("%lld", &a[i]);
for(int i = 0; i <= m; i++) scanf("%lld", &b[i]);
int limit = ntt1.init(n + m);
ntt2.init(n + m), ntt3.init(n + m);
for(int i = 0; i <= limit; i++) ca[i] = a[i], cb[i] = b[i];
ntt1.trans(ca, 1), ntt1.trans(cb, 1);
for(int i = 0; i <= limit; i++) ans1[i] = ca[i] * cb[i] % P1;
ntt1.trans(ans1, -1);
for(int i = 0; i <= limit; i++) ca[i] = a[i], cb[i] = b[i];
ntt2.trans(ca, 1), ntt2.trans(cb, 1);
for(int i = 0; i <= limit; i++) ans2[i] = ca[i] * cb[i] % P2;
ntt2.trans(ans2, -1);
for(int i = 0; i <= limit; i++) ca[i] = a[i], cb[i] = b[i];
ntt3.trans(ca, 1), ntt3.trans(cb, 1);
for(int i = 0; i <= limit; i++) ans3[i] = ca[i] * cb[i] % P3;
ntt3.trans(ans3, -1);
for(int i = 0; i <= n + m; i++) {
// 三个质数可以手推 CRT
// 看着这个推也可以 https://www.cnblogs.com/Memory-of-winter/p/10223844.html
LL out = 0;
LL tmp = ans1[i] + (ans2[i] - ans1[i] + P2) % P2 * inv(P1, P2) % P2 * P1;
out = (tmp + (ans3[i] - tmp % P3 + P3) % P3 * inv(P1 * P2 % P3, P3) % P3 * P1 % P * P2 % P) % P;
printf("%lld ", out);
}
puts("");
return 0;
}
{% endnote %}
完结撒花~
https://www.cnblogs.com/zwfymqz/p/8244902.html
https://www.bilibili.com/video/BV1Y7411W73U
https://www.luogu.com.cn/problem/solution/P4245
https://www.cnblogs.com/Memory-of-winter/p/10223844.html
https://blog.csdn.net/zhouyuheng2003/article/details/85561887
https://www.cnblogs.com/Sakits/p/8416918.html
https://oi-wiki.org/math/poly/ntt/
https://www.cnblogs.com/zarth/p/7288456.html
http://www.longluo.me/blog/2022/05/01/Number-Theoretic-Transform/
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· 单线程的Redis速度为什么快?
· 展开说说关于C#中ORM框架的用法!
· Pantheons:用 TypeScript 打造主流大模型对话的一站式集成库
· SQL Server 2025 AI相关能力初探