【数学】【多项式】快速数论变换(NTT)
引入
求两个多项式的卷积
Description
给定两个多项式 \(F\left(x\right), G\left(x\right)\) 的系数表示法,求两个多项式的卷积。
如:
得到的结果即为:
直接由系数计算的复杂度是 \(\mathcal O\left(n ^ 2\right)\) 级别的,而我们发现在点值表示下,计算卷积的复杂度是 \(\mathcal O\left(n\right)\) 的,因此考虑把系数表示转化为点值表示然后再做乘法。
这样为了实现这个转化,我们引入了快速傅里叶变换(FFT)。通过把多项式按奇偶次分成两个关于 \(x^2\) 的多项式,并不断递归下去实现这个过程,并通过求得的一对 \(x\) 两个子多项式分别的值,一相加,一相减得到原多项式中 \(x\) 满足正负配对的一对点值。这个过程中为了满足 \(x\) 取值的正负配对,我们把 \(x\) 的取值扩充到复数域,取 \(n\) 次单位根。这样实现了由系数向点值的转化 DFT。
在由点值转化回系数的时候,我们从矩阵运算的角度,意识到 IDFT 矩阵应为 DFT 矩阵的逆矩阵,而 DFT 矩阵的逆矩阵又只需要对每一项取其倒数,并乘上 \(n^{-1}\)。因此我们在 DFT 上稍加修改就实现了 IDFT。
FFT 存在的一点不优美之处
根据 FFT 的过程,我们可知因为复数运算,FFT 对运算的精度要求很高,这也就影响了算法的时空开销。
而这个瓶颈似乎就是出在使用了复数运算上。
考虑是否能在 \(\mathbf R\) 甚至是 \(\mathbf Z\) 范围内找到一种能够替代 \(n\) 次单位根的某种神奇的元素。
它存在吗?当然存在。这就是快速数论变换(NTT)用到的原根。
一点数论知识和概念
阶
若对于 \(a, p\),存在正整数 \(l\) 使得 \(a^l \equiv 1 \left(\bmod p\right)\),则把满足这个条件的最小的正整数 \(l\) 称为 \(a\) 在模 \(p\) 下的阶。
举例:
当 \(a = 2, p = 5\) 时,因为 \(2 ^ 4 = 16 \equiv 1 \left(\bmod 5\right)\),可以证明,\(4\) 是满足该条件的最小正整数。因此 \(4\) 就是 \(2\) 在模 \(5\) 意义下的阶。
给定 \(a, p\) 求 \(l\) 的方法是一个典型的高次同余方程,可以使用 BSGS 算法求解。
原根
若 \(g\) 在模 \(p\) 意义下的阶是 \(\varphi\left(p\right)\),其中 \(\varphi\) 是欧拉函数,则 \(g\) 称为模 \(p\) 意义下的原根。
如何求一个原根呢?
关于原根,由如下的结论:
若 \(\gcd\left(g, p\right) = 1\),设 \(p_1, p_2,\cdots,p_k\) 是 \(\varphi\left(p\right)\) 所有的质因数,则 \(g\) 是模 \(p\) 的一个原根,当且仅当对于任意的 \(p_i\),都有 \(g^{\frac{\varphi\left(p\right)}{p_i}} \not\equiv 1 \left(\bmod p\right)\)。
可以证明,\(p\) 的最小原根在 \(p ^ {0.25}\) 级别。因此枚举原根的时间一般都是可以接受的。
快速数论变换(NTT)
原根的性质
根据阶和原根的定义,我们可以发现原根一个很优美的性质,那就是 \(\{g, g ^ 2, g ^ 3, \cdots, g ^ {\varphi\left(p\right)}\}\) 在模 \(p\) 下互不相同。次数若继续增大,则在模 \(p\) 下就会出现循环。这似乎与单位根的性质非常相似。
特殊的,若 \(p\) 是质数,\(\varphi\left(p\right) = p - 1\)。若 \(p - 1\) 的分解中含有较多的 \(2\) 这个因子,就可以把它的原根作拿来代替单位根了!
这个模数 \(p\) 一般取 \(998244353\left(2^{23} \times 7 \times 17 + 1\right)\) 即可。此时原根 \(g = 3\)。
只要多项式的长度不要求在 \(8 \times 10^6\) 以上,这个模数和原根基本就够用。
实现
为了实现 \(n\) 次单位根乘 \(n\) 次方会出现循环的性质,联系上文提到的原根的性质,直接用 \(g^{\frac{\varphi\left(p\right)}{n}}\)(\(p\) 为质数时为 \(g^{\frac{p - 1}{n}}\))来替换 \(n\) 次单位根即可。
这样,我们在 FFT 上稍作修改,就实现了 NTT。
NTT 实现 IDFT 的时候,存在两种效果等价的方式,一种是直接对转换后的序列 reverse
一下,另一种是取原根的逆元。我选择了后者,在 \(p = 998244353\) 时这个数为 \(332748118\)。
代码实现
void NTT(LL a[], int len, int type) {
for(int i = 0; i < len; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int h = 2; h <= len; h <<= 1) {
LL gn = qpow(type == 1 ? g : invg , ((Mod - 1) / (LL)h));
for(int j = 0; j < len; j += h) {
LL gk = 1;
for(int k = j; k < j + h / 2; ++k) {
LL e = a[k], o = gk * a[k + h / 2] % Mod;
a[k] = (e + o) % Mod; a[k + h / 2] = ((e - o) % Mod + Mod) % Mod;
gk = gk * gn % Mod;
}
}
}
if(type == -1) {
LL inv = qpow(len, Mod - 2);
for(int i = 0; i < len; ++i) a[i] = a[i] * inv % Mod;
}
}
其他
完整代码
#include <bits/stdc++.h>
#define LL long long
template <typename Temp> inline void read(Temp & res) {
Temp fh = 1; res = 0; char ch = getchar();
for(; !isdigit(ch); ch = getchar()) if(ch == '-') fh = -1;
for(; isdigit(ch); ch = getchar()) res = (res << 3) + (res << 1) + (ch ^ '0');
res = res * fh;
}
using namespace std;
const int Maxn = 2097200;
const LL Mod = 998244353, g = 3, invg = 332748118;
inline LL qpow(LL A, LL P) {
LL res = 1;
while(P) {
if(P & 1) res = res * A % Mod;
A = A * A % Mod;
P >>= 1;
}
return res;
}
int rev[Maxn];
void Init(int len) {
for(int i = 1; i < len; ++i) {
rev[i] = rev[i >> 1] >> 1;
if(i & 1) rev[i] |= len >> 1;
}
}
void NTT(LL a[], int len, int type) {
for(int i = 0; i < len; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int h = 2; h <= len; h <<= 1) {
LL gn = qpow(type == 1 ? g : invg , ((Mod - 1) / (LL)h));
for(int j = 0; j < len; j += h) {
LL gk = 1;
for(int k = j; k < j + h / 2; ++k) {
LL e = a[k], o = gk * a[k + h / 2] % Mod;
a[k] = (e + o) % Mod; a[k + h / 2] = ((e - o) % Mod + Mod) % Mod;
gk = gk * gn % Mod;
}
}
}
if(type == -1) {
LL inv = qpow(len, Mod - 2);
for(int i = 0; i < len; ++i) a[i] = a[i] * inv % Mod;
}
}
void polymul(LL a[], LL b[], int lenA, int lenB) {
int L = lenA + lenB, len = 1; while(L) {len <<= 1; L >>= 1;}
Init(len); NTT(a, len, 1); NTT(b, len, 1);
for(int i = 0; i < len; ++i) a[i] = a[i] * b[i];
NTT(a, len, -1);
}
int n, m;
LL A[Maxn], B[Maxn];
signed main() {
read(n); read(m);
for(int i = 0; i <= n; ++i) read(A[i]);
for(int i = 0; i <= m; ++i) read(B[i]);
polymul(A, B, n, m);
for(int i = 0; i <= n + m; ++i) printf("%lld ", A[i]);
return 0;
}
鸣谢
(排名不分先后)
OI-Wiki