FFT & FNT 简要整理
几周前搞了搞……有点时间简要整理一下,诸多不足之处还请指出。
有哪些需要理解的地方?
- 点值表示:对于多项式 \(A(x)\),把 \(n\) 个不同的 \(x\) 代入,会得出 \(n\) 个不同的 \(y\),在坐标系内就是 \(n\) 个不同的点,那么这 \(n\) 个点唯一确定该多项式
- 为什么引入单位根 \(\omega\) 作为变量 \(x\):若代入一些 \(x\) ,使每个 \(x\) 的若干次方等于 \(1\),就不用做全部的次方运算了
- 单位根的性质:于是可以分治实现 \(FFT\),复杂度降至 \(O(n log n)\)
- 共轭复数:复数 \(z=a+bi\) 的共轭复数为 \(a−bi\)(虚部取反)。一个多项式在分治的过程中乘上单位根的共轭复数,分治完的每一项除以 \(n\) 即为原多项式的每一项系数
- 初始化序列反转二进制数 : \(r[i] = (r[i >> 1] >> 1) | ((i \& 1) << (l - 1))\)
- 利用二进制序列优化 \(FFT\):每个位置分治后的最终位置为其二进制翻转后得到的位置
算法的步骤?
- 输入,将向量长度转化成二的幂次
- 初始化序列反转二进制数 : \(r[i] = (r[i >> 1] >> 1) | ((i \& 1) << (l - 1))\)
- 利用 \(FFT\) 将多项式转化为点值表示,如下:
- 利用处理出的二进制求出待处理序列
- 分治,每次找序列的一半,利用单位根计算(乘上 \(k\) 次幂),后对称到另一半
- 将点值表示相乘,利用 \(IFFT\) 将其转换回系数表达式,输出
- 多项式在分治的过程中乘上单位根的共轭复数,分治完的每一项除 \(n\) 即为原多项式的每一项系数
\(NTT\) 其实就是在模意义下做了些许改变,十分类似。 - \(NTT\) 模数:原根为 \(3\),有 \(469762049\),\(998244353\),\(1004535809\)
- 原根在此特殊意义下代替了单位根,除法相应地变成逆元的乘法运算
- 在模数并不是 \(NTT\) 模数时,采用 \(MTT\) 或三模数 \(NTT\) + \(CRT\)
\(FFT\) 和 \(FNT\) 板子?
// Fast_fourier_transform
#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const double pi = acos(-1.0);
const int maxn = 5000000 + 10;
int n, m, len, limit, rado, r[maxn];
struct Complex {
double x, y;
Complex(double tx = 0, double ty = 0) { x = tx, y = ty; }
} f_1[maxn], f_2[maxn];
inline Complex operator + (const Complex &a, const Complex &b) { return Complex(a.x + b.x, a.y + b.y); }
inline Complex operator - (const Complex &a, const Complex &b) { return Complex(a.x - b.x, a.y - b.y); }
inline Complex operator * (const Complex &a, const Complex &b) { return Complex(a.x * b.x - a.y * b.y, a.x * b.y + b.x * a.y); }
inline int read() {
register char ch = 0; register int w = 0, x = 0;
while( !isdigit(ch) ) w |= (ch == '-'), ch = getchar();
while( isdigit(ch) ) x = (x * 10) + (ch ^ 48), ch = getchar();
return w ? -x : x;
}
inline void Fast_fourior_transform(Complex *a, int type) {
for(int i = 0; i < limit; ++i) if( i < r[i] ) swap(a[i], a[r[i]]);
for(int mid = 1; mid < limit; mid = mid << 1) {
Complex Base_w = Complex(cos(pi / mid), type * sin(pi / mid));
for(int len = mid << 1, l = 0; l < limit; l = l + len) {
Complex w = Complex(1, 0);
for(int k = 0; k < mid; ++k, w = w * Base_w) {
Complex tmp_1 = a[l + k], tmp_2 = w * a[l + mid + k];
a[l + k] = tmp_1 + tmp_2, a[l + mid + k] = tmp_1 - tmp_2;
}
}
}
}
int main(int argc, char const *argv[])
{
freopen("..\\nanjolno.in", "r", stdin);
freopen("..\\nanjolno.out", "w", stdout);
scanf("%d%d", &n, &m);
for(int i = 0; i <= n; ++i) f_1[i].x = read();
for(int i = 0; i <= m; ++i) f_2[i].x = read();
len = n + m, limit = 1, rado = 0;
while( limit <= len ) limit = limit << 1, ++rado;
for(int i = 0; i < limit; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (rado - 1));
Fast_fourior_transform(f_1, 1);
Fast_fourior_transform(f_2, 1);
for(int i = 0; i < limit; ++i) f_1[i] = f_1[i] * f_2[i];
Fast_fourior_transform(f_1, -1);
for(int i = 0; i <= len; ++i) printf("%d\n", (int)(f_1[i].x / limit + 0.1));
fclose(stdin), fclose(stdout);
return 0;
}
// Fast_number_theory_transform
#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int mod = 998244353;
const int maxn = 5000000 + 10;
int n, m, len, limit, rado, f_1[maxn], f_2[maxn], r[maxn];
inline int read() {
register char ch = 0; register int w = 0, x = 0;
while( !isdigit(ch) ) w |= (ch == '-'), ch = getchar();
while( isdigit(ch) ) x = (x * 10) + (ch ^ 48), ch = getchar();
return w ? -x : x;
}
inline int Fast_pow(int a, int p) {
long long x = a, ans = 1ll;
for( ; p; x = x * x % mod, p = p >> 1) if( p & 1 ) ans = x * ans % mod;
return (int)ans;
}
inline void Fast_numbertheory_transform(int *a, int limit, int type) {
int rado = bit[limit];
for(int i = 0; i < limit; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (rado - 1));
for(int i = 0; i < limit; ++i) if( i < r[i] ) swap(a[i], a[r[i]]);
for(int mid = 1; mid < limit; mid = mid << 1) {
int Base_p = Fast_pow(3ll, (mod - 1) / (mid << 1));
if( type == -1 ) Base_p = Fast_pow(Base_p, mod - 2);
for(int l = 0, length = mid << 1; l < limit; l = l + length) {
for(int k = 0, p = 1; k < mid; ++k, p = 1ll * p * Base_p % mod) {
int x = a[l + k], y = 1ll * p * a[l + mid + k] % mod;
a[l + k] = (x + y) % mod, a[l + mid + k] = (x - y + mod) % mod;
}
}
}
if( type == -1 ) for(int i = 0; i < limit; ++i) a[i] = 1ll * a[i] * Fast_pow(limit, mod - 2) % mod;
}
int main(int argc, char const *argv[])
{
freopen("..\\nanjolno.in", "r", stdin);
freopen("..\\nanjolno.out", "w", stdout);
scanf("%d%d", &n, &m);
for(int i = 0; i <= n; ++i) f_1[i] = read();
for(int i = 0; i <= m; ++i) f_2[i] = read();
len = n + m, limit = 1, rado = 0;
while( limit <= len ) limit = limit << 1, ++rado;
for(int i = 0; i < limit; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (rado - 1));
Fast_numbertheory_transform(f_1, 1);
Fast_numbertheory_transform(f_2, 1);
for(int i = 0; i < limit; ++i) f_1[i] = 1ll * f_1[i] * f_2[i] % mod;
Fast_numbertheory_transform(f_1, -1);
for(int i = 0; i <= len; ++i) printf("%d ", f_1[i]);
fclose(stdin), fclose(stdout);
return 0;
}
流萤断续光,一明一灭一尺间,寂寞何以堪。