FFT&NTT

Part 1 FFT

网上FFT和NTT的博客基本都是先铺一大堆前置知识,直接劝退,但其实FFT是个很好理解的东西,也不需要那些前置知识。

那些前置知识无非就是引出单位根,但我们可以直接定义单位根。

只需要记住,\(n\) 次单位根 \(\omega_n\) 是一个满足如下性质的数:

  1. \(\omega_{kn}^{ki}=\omega_n^i\)
  2. \(\omega_n^k=-\omega_n^{k+n/2}\)

如果要求的话,记住 \(\omega_{n}=\cos\frac{2\pi}{n}+i\sin \frac{2\pi}{n}\)。是个复数。
在计算多项式 \(f,g\) 的乘积时,FFT先将 \(f,g\) 分别转为点值表示,直接相乘得到结果的点值表示,再转回系数表示。为了方便,先把 \(f,g\) 的长度补到 \(2\) 的整数次幂。

FFT 的点值不是随便取的,是将 \(\omega_n^0,\omega_n^1...\omega_n^n-1\) 分别带入得到的。

如果要计算 \(A(\omega_n^k)\),设 \(A(x)=a_0+a_1x+a_2x^2...+a_{n-1}x^{n-1}\)。不妨把 \(A\) 的每个项按奇偶性分类,设两个多项式 \(A_1(x)=a_0+a_2x^2...+a_{n-2}x^{n-2},A_2(x)=a_1+a_3x^3...+a_{n-1}x^{n-1}\),那么根据上面的性质可以很轻松地得到:

\[\begin{align} A(\omega_n^k)&=A_1(\omega_n^{2k})+\omega_n^kA_2(\omega_n^{2k})\nonumber\\ &=A_1(\omega_{n/2}^k)+\omega_n^kA_2(\omega_{n/2}^k)\nonumber \end{align} \]

\[\begin{align} A(\omega_n^{k+n/2})&=A_1(\omega_n^{2k+n})+\omega_n^{k+n/2}A_2(\omega_n^{2k+n})\nonumber\\ &=A_1(\omega_n^{2k})+-\omega_n^kA_2(\omega_n^{2k})\nonumber\\ &=A_1(\omega_{n/2}^k)+-\omega_n^kA_2(\omega_{n/2}^k)\nonumber \end{align} \]

因此,只需要计算出 \(A_1,A_2\)\(\omega_{n/2}^k\) 处的取值即可 \(O(n)\) 求出 \(A\)\(\omega_n^k\) 处的取值。

\(A_1,A_2\) 都是规模减半的子问题,所以可以递归处理,时间复杂度 \(O(n\log n)\)

Part 2 IFFT

问题来了,怎么把一个点值表示出的多项式转回系数表示呢。

假设有一个多项式 \(A(x)=a_0+a_1x...+a_{n-1}x^{n-1}\),我们求出了它在 \(\omega_n^k\) 处的每个取值 \((b_0,b_1...b_{n-1})\)

设一个多项式 \(B(x)=b_0+b_1x...+b_{n-1}x^{n-1}\)。把单位根的倒数,即 \(\omega_n,\omega_n^{-1},\omega_n^{-2}...\omega_n^{-k}\) 带入,得到的点值依次是 \(c_0,c_1...c_{n-1}\)

\[\begin{align} c_x&=\sum\limits^{n-1}_{i=0}\omega_n^{-ix}b_i\nonumber\\ &=\sum\limits^{n-1}_{i=0}\omega_n^{-ix}\sum\limits^{n-1}_{j=0}\omega_n^{ij}a_j\nonumber\\ &=\sum\limits^{n-1}_{i=0}\sum\limits^{n-1}_{j=0}\omega_n^{ij-ix}a_j\nonumber\\ &=\sum\limits^{n-1}_{j=0}a_j\sum\limits^{n-1}_{i=0}\omega_n^{ij-ix}\nonumber \end{align} \]

\(\sum\limits^{n-1}_{i=0}\omega_n^{ij-ix}\)\(j=x\) 时为 \(1\),否则等比数列求和后这东西就是 \(0\)

因此,\(c_i=na_i\),即 \(a_i=\frac{c_i}{n}\)

现在就可以通过 \(b\) 求出 \(a\) 了。不难发现把单位根倒数带进去和把单位根带进去并没有什么区别。

Part 3 NTT

其实就是用原根顶替了单位根。

虽然不能这么说,但你可以直接认为 \(\omega_{n}\) 在模 \(p\) 意义下为 \(g^{\frac{p-1}{n}}\)\(n\) 被我们补为了 \(2^k\),所以常用的模数 998244353 你会发现它等于 \(2^{23}\) 乘上一坨。

然后NTT和FFT除了好写了点以为就没有区别了。

FFT的迭代优化和蝴蝶操作在NTT中效果貌似都不明显?

贴个NTT的代码:

#include <cstdio>
#define gc (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 100000, stdin), p1 == p2) ? EOF : *p1 ++)

char buf[100000], *p1, *p2;
inline int read() {
	char ch;
	int x = 0;
	while ((ch = gc) < 48);
	do x = x * 10 + ch - 48; while ((ch = gc) >= 48);
	return x;
}
const int mod = 998244353, G = 3;
inline void add(int &x, int y) {(x += y) >= mod && (x -= mod);}
inline int mns(int x, int y) {return x >= y ? x - y : x - y + mod;}

inline int qpow(int a, int b) {
	int ret = 1;
	while (b) {
		if (b & 1) ret = 1ll * ret * a % mod;
		a = 1ll * a * a % mod, b >>= 1;
	}
	return ret;
}
int omega[22][3100000], inv[22][3100000], Log[3100000], rev[3100000];
inline void swap(int &x, int &y) {
	int t = x; x = y; y = t;
}
void NTT(int *a, int n, int type) {
	for (int i = 0; i < n; ++ i)
		if (i < rev[i]) swap(a[i], a[rev[i]]);
	for (int i = 1; i < n; i <<= 1) {
		for (int j = 0; j < n; j += i << 1)
		for (int k = 0; k < i; ++ k) {
			int x = 1ll * (type ? inv[Log[i]][k] : omega[Log[i]][k]) * a[i + j + k] % mod;
			a[i + j + k] = mns(a[j + k], x), add(a[j + k], x);
		}
	}
}
int f[3100000], g[3100000];

int main() {
	int n = read(), m = read();
	for (int i = 0; i <= n; ++ i) f[i] = read();
	for (int i = 0; i <= m; ++ i) g[i] = read();
	int lim = 1, k = 0;
	while (lim < n + m + 1) lim <<= 1, ++ k;
	for (int i = 2; i <= lim; ++ i) Log[i] = Log[i >> 1] = 1; 
	for (int i = 0; i <= k; ++ i) {
		int org = qpow(G, (mod - 1) >> i), invorg = qpow(org, mod - 2);
		omega[i][0] = inv[i][0] = 1;
		for (int j = 1; j < 1 << i && j < lim; ++ j)
			omega[i][j] = 1ll * omega[i][j - 1] * org % mod, inv[i][j] = 1ll * inv[i][j - 1] * invorg % mod;
	}
	for (int i = 2; i <= lim; ++ i) Log[i] = Log[i >> 1] + 1;
	for (int i = 0; i < lim; ++ i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << k - 1;
	NTT(f, lim, 0), NTT(g, lim, 0);
	for (int i = 0; i < lim; ++ i) f[i] = 1ll * f[i] * g[i] % mod;
	NTT(f, lim, 1);
	for (int i = 0; i < lim; ++ i) f[i] = 1ll * f[i] * qpow(lim, mod - 2) % mod;
	for (int i = 0; i <= n + m; ++ i) printf("%d ", (int)(f[i] + mod) % mod);
	return 0;
}
posted @ 2022-06-26 14:05  zqs2020  阅读(31)  评论(0编辑  收藏  举报