多项式乘法

给定两个多项式

\(A(x) = \sum_{i=0}^{n-1}{a_i x^i}\)

\(B(x)=\sum^{m-1}_{i=0}{b_i x^i}\)

\(C(x) = A(x) \times B(x) = \sum_{i=0}^{n+m-1}{c_i x^i}\)

前置知识

单位根

\(n\) 次单位根即为满足 \(x^n=1\) 的 x。

由代数基本定理可得 \(n\) 次单位根应该有 \(n\) 个。

\(k\) 个记为 \(\omega^k_n\)

单位根的一些性质

\[\omega_n^k = \cos {\dfrac{2k\pi}{n}} + \mathcal{i} \sin{\dfrac{2k\pi}{n}} (k = 0, 1, \dots, n - 1) \\ \omega^{2k}_{2n} = \omega^{k}_{n} \\ \omega^{k + \frac{n}{2}}_{n}=-\omega^{k}_{n} \\ \omega^i_n \omega^j_n = \omega^{i+j}_n \\ (\omega^k_n)^j = \omega^{jk}_n \\ \]

从大佬博客偷来的图

原根的性质(\(p\) 为质数)

\[\forall i, j \in \Z, g^i \not \equiv g^j \left( mod p \right) \\ g^{p-1} \equiv 1 (modp) \\ \]

多项式的点表示法

众所周知,一个 \(n - 1\) 次多项式可以用 \(n\) 个互异的点表示出来。

这就叫做多项式的点表示法。

DFT

上面的点表示法启示我们了一种求多项式乘法的方法。

一个 \(n\) 次的多项式 \(A(x)\),对于取值 \(\{x_1, x_2, x_3, \dots, x_n\}\),可以求出\(\{y_{A_1}, y_{A_2}, y_{A_3}, \dots, y_{A_n}\}\)

同理一个 \(n\) 次的多项式 \(B(x)\),也可以求出相对应的 \(\{y_{B_1}, y_{B_2}, y_{B_3}, \dots, y_{B_n}\}\)

然后我们对于每一个 \(x_i\) 可以求出一个对应的 \(y_{C_i}=y_{A_i}\times y_{B_i}\)

得到了 \(\{y_{C_1}, y_{C_2}, y_{C_3}, \dots, y_{C_n}\}\),我们就可以用拉格朗日插值法求出多项式\(C(x)=A(x) \times B(x)\) 了。

当我们将取值 \(\{x_1, x_2, x_3, \dots, x_n\}\) 取为 \(\{ \omega^0_n, \omega^1_n, \omega^2_n, \dots, \omega^{n-1}_n \}\) 时我们就可以少算很多的次幂。

利用上述方法求多项式乘法的方法就叫做 DFT

FFT

上面的 DFT 虽然看起来很妙,但仍然是 \(\mathcal{O}(n^2)\) 的,其主要复杂度为求出对应的 \(\{y_{A_1}, y_{A_2}, y_{A_3}, \dots, y_{A_n}\}\)\(\{y_{B_1}, y_{B_2}, y_{B_3}, \dots, y_{B_n}\}\)

我们考虑优化这个过程。

接下来的过程中我们假设 \(n\)\(2\) 的整数次幂。

\[A(x) = a_0 + a_1x + a_2x^2 + a_3x^3 + \dots + a_{n-1}x^{n-1} \]

按照 \(x\) 次数的奇偶分类。

\[A(x) = (a_0 + a_2x^2 + \dots + a_{n-2}x^{n-2}) + x(a_1 + a_3x^2 + \dots + a_{n-1}x^{n-2}) \]

\[A_1(x) = (a_0 + a_2x + \dots + a_{n-2}x^{\frac{n-2}{2}}) \]

\[A_2(x) = (a_1 + a_3x + \dots + a_{n-1}x^{\frac{n-2}{2}}) \]

\(A(x) = A_1(x^2) + xA_2(x^2)\)

显然 \(A_1(x)\)\(A_2(x)\) 都为 \(\dfrac{n}{2}\) 次多项式。

接下来就是 FFT 的核心。

对于所有 \(k \leq \frac{n}{2}\),有

\[A(\omega^k_n) = A_1(\omega^{2k}_n ) + \omega^k_n A_2(\omega^{2k}_n) \\ A(\omega^k_n) = A_1(\omega^{k}_{\frac{n}{2}} ) + \omega^{k}_{\frac{n}{2}}A_2(\omega^{k}_{\frac{n}{2}}) \]

\[A(\omega^{k + \frac{n}{2}}_n) = A_1(\omega^{2k + n}_n) + \omega^{k + \frac{n}{2}}_n A_2(\omega^{2k + n}_n) \\ A(\omega^{k + \frac{n}{2}}_n) = A_1(\omega^{2k}_n ) - \omega^k_n A_2(\omega^{2k}_n) \\ A(\omega^{k + \frac{n}{2}}_n) = A_1(\omega^{k}_{\frac{n}{2}} ) - \omega^{k}_{\frac{n}{2}}A_2(\omega^{k}_{\frac{n}{2}}) \]

然后我们就发现了一个很神奇的东西,如果我们知道了 \(A_1(\omega^{k}_{\frac{n}{2}})\)\(A_2(\omega^{k}_{\frac{n}{2}})\) 我们就可以同时求出 \(A(\omega^k_n)\)\(A(\omega^{k + \frac{n}{2}}_n)\)

并且由于 \(A_1(x)\)\(A_2(x)\) 都为 \(\dfrac{n}{2}\) 次多项式,我们可以用相同的办法求出他们,并且每次将次数缩小到 \(\dfrac{1}{2}\)

当递归到 \(1\) 次时直接带入求值即可。

这就是一个类似于分治的复杂度了,不难证明这个过程是 \(\mathcal{O}(n \log n)\) 的。

不过我们仍然还有问题需要解决,上面我们只优化了求出对应的点坐标的复杂度,并没有优化求出系数的复杂度,如果我们用拉格朗日插值来求出系数这个算法的复杂度还是 \(\mathcal{O}(n^2)\) 的。

但这就是 DFT 的优越之处,它的特殊取值使得它可以同样在 \(\mathcal{O}(n \log n)\) 的复杂度内求出对应的系数。

IFFT

IFFT 即将点值变为系数的过程。

有个比较神奇的结论,记我们通过 FFT 算出来的结果的乘积为 \(\{y_{C_0}, y_{C_1}, y_{C_2}, \dots, y_{C_{n-1}}\}\)

记多项式 \(D(x) = y_{C_0} + y_{C_1}x + y_{C_2}x^2 + \dots + y_{C_{n-1}}x^{n - 1}\)\(\{ \omega^{-0}_n, \omega^{-1}_n, \omega^{-2}_n, \dots, \omega^{-(n-1)}_n \}\) 处的取值为 \(\{ d_0, d_1, d_2, \dots, d_{n-1} \}\)

\(c_i = \dfrac{d_i}{n}\)

\(C(x) = \dfrac{d_0}{n} + \dfrac{d_1}{n}x + \dfrac{d_2}{n}x^2 + \dots + \dfrac{d_{n-1}}{n}x^{n-1}\)

IFFT 证明

\[d_k = \sum^{n-1}_{i=0}{y_{C_i} (\omega^{-k}_n)^i} \\ d_k =\sum^{n-1}_{i=0}{(\sum^{n-1}_{j=0}{c_j (\omega^i_n)^j}) (\omega^{-k}_n)^i} \\ d_k = \sum^{n-1}_{i=0}{(\sum^{n-1}_{j=0}{c_j (\omega^j_n)^i}) (\omega^{-k}_n)^i} \\ d_k =\sum^{n-1}_{i=0}{\sum^{n-1}_{j=0}{c_j (\omega^{j-k}_n)^i}} \\ d_k =\sum^{n-1}_{j=0}{c_j \sum^{n-1}_{i=0}{(\omega^{j-k}_n)^i}} \\ \]

\(S(k) = \sum_{i=0}^{n-1}{(\omega^k_n)^i}\)

根据等比数列求和公式,首项为 \(1\),公比为 \(w^k_n\) ,当公比不为 \(0\),即 $k \neq 0 $ 时,有:

\[S(k) = \dfrac{1 - (\omega^k_n)^n}{1 - \omega^k_n} = \dfrac{1 - (\omega^n_n)^k}{1 - \omega^k_n} = \dfrac{1 - 1^k}{1 - \omega^k_n} = 0 \]

\(k = 0\) 时,有:

\[S(k) = \sum_{i=0}^{n-1}{1^i} = n \]

所以不难得出

\[d_k = \sum^{n-1}_{j=0}S(j-k) c_j \\ d_k = nc_k \\ c_k = \dfrac{d_k}{n} \]

FFT 未优化代码

#include <bits/stdc++.h>
using namespace std;

namespace Math {
	#define PI 3.14159265358979323846

	struct Complex { double r, i; Complex(double r = 0, double i = 0): r(r), i(i) { } };
	Complex operator + (const Complex &a, const Complex &b) { return { a.r + b.r, a.i + b.i }; }
	Complex operator - (const Complex &a, const Complex &b) { return { a.r - b.r, a.i - b.i }; }
	Complex operator * (const Complex &a, const Complex &b) { return { a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r }; }

	void FFT(int len, vector<Complex> &a, int op) {
		if (len == 1) return;
		vector<Complex> a1(len / 2), a2(len / 2);
		for (int i = 0; i < a.size(); i += 2) a1[i / 2] = a[i], a2[i / 2] = a[i + 1];
		FFT(len / 2, a1, op), FFT(len / 2, a2, op);
		Complex step = { cos(PI * 2 / len), op * sin(PI * 2 / len) }, omega = { 1, 0 };
		for (int i = 0; i < len / 2; i ++, omega = omega * step) a[i] = a1[i] + omega * a2[i], a[i + len / 2] = a1[i] - omega * a2[i]; 
	} 
	
	vector<int> polyMul(vector<int> a, vector<int> b) {
		int n = a.size(), m = b.size();
		int len = 1;
		while (len < n + m - 1) len *= 2;
		vector<Complex> ac(len), bc(len);
		Complex init = { 0, 0 };
		for (int i = 0; i < len; i ++) ac[i] = (i < n ? a[i] : init); 
		for (int i = 0; i < len; i ++) bc[i] = (i < m ? b[i] : init); 
		FFT(len, ac, 1), FFT(len, bc, 1); // FFT
		vector<Complex> cc(len);
		for (int i = 0; i < len; i ++) cc[i] = ac[i] * bc[i];
		FFT(len, cc, -1); // IFFT
		vector<int> c(n + m - 1);
		for (int i = 0; i < n + m - 1; i ++) c[i] = cc[i].r / len + 0.5;
		return c;
	}

	#undef PI
}

int main() {
	int n, m; scanf("%d%d", &n, &m);
	vector<int> a(n + 1), b(m + 1);
	for (int i = 0; i <= n; i ++) scanf("%d", &a[i]);
	for (int i = 0; i <= m; i ++) scanf("%d", &b[i]);
	vector<int> c = Math::polyMul(a, b);
	for (int i = 0; i < c.size(); i ++) printf("%d ", c[i]);
	return 0;
}

虽然看网上博客说分治版本的 FFT 没法通过洛谷的模板题,但我莫名其妙过了……

但是常数确实是比较大的。

image.png

FFT 优化

主要是可以优化掉递归的常数。

不难发现可以直接通过下表二进制反转后分组来规避掉递归的常数。(为什么?因为奇偶分组是用最低位来决定的,所以考虑这个序列最终的顺序就应该倒过来)

代码:

#include <bits/stdc++.h>
using namespace std;

namespace Math {
	#define PI 3.14159265358979323846

	struct Complex { double r, i; Complex(double r = 0, double i = 0): r(r), i(i) { } };
	Complex operator + (const Complex &a, const Complex &b) { return { a.r + b.r, a.i + b.i }; }
	Complex operator - (const Complex &a, const Complex &b) { return { a.r - b.r, a.i - b.i }; }
	Complex operator * (const Complex &a, const Complex &b) { return { a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r }; }

	vector<int> rev;

	void init_FFT(int len, int cnt) {
		rev.resize(len);
		for (int i = 0; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
	} 

	// fast implement
	void FFT(int len, vector<Complex> &a, int op) {
		for (int i = 0; i < len; i ++) if (i < rev[i]) swap(a[i], a[rev[i]]);
		for (int t = 1; t < len; t <<= 1) {
			Complex step = { cos(PI / t), op * sin(PI / t) };
			for (int i = 0; i < len; i += 2 * t) {
				Complex omega = { 1, 0 };
				for (int j = i; j < i + t; j ++, omega = omega * step) {
					Complex x = a[j], y = omega * a[j + t];
					a[j] = x + y, a[j + t] = x - y;
				}
			}
		}
	}
	
	vector<int> polyMul(vector<int> a, vector<int> b) {
		int n = a.size(), m = b.size();
		int len = 1, cnt = 0;
		while (len < n + m - 1) len *= 2, ++ cnt;
		init_FFT(len, cnt);
		vector<Complex> ac(len), bc(len);
		Complex init = { 0, 0 };
		for (int i = 0; i < len; i ++) ac[i] = (i < n ? a[i] : init); 
		for (int i = 0; i < len; i ++) bc[i] = (i < m ? b[i] : init); 
		FFT(len, ac, 1), FFT(len, bc, 1); // FFT
		vector<Complex> cc(len);
		for (int i = 0; i < len; i ++) cc[i] = ac[i] * bc[i];
		FFT(len, cc, -1); // IFFT
		vector<int> c(n + m - 1);
		for (int i = 0; i < n + m - 1; i ++) c[i] = cc[i].r / len + 0.5;
		return c;
	}

	#undef PI
}

int main() {
	int n, m; scanf("%d%d", &n, &m);
	vector<int> a(n + 1), b(m + 1);
	for (int i = 0; i <= n; i ++) scanf("%d", &a[i]);
	for (int i = 0; i <= m; i ++) scanf("%d", &b[i]);
	vector<int> c = Math::polyMul(a, b);
	for (int i = 0; i < c.size(); i ++) printf("%d ", c[i]);
	return 0;
}

优化掉了一秒,可喜可贺。

image.png

NTT

FFT 中由于要使用单位根,精度是一个很大的问题,所以就有了 NTT。

在下文中,\(n\)\(2\) 的整数次幂,\(p\) 为质数且 \(n \mid p - 1\)\(g\)\(p\) 的一个原根。

我们设 \(h^k_n = (g^{\frac{p-1}{n}})^k\)

那么 \(\{h^0_n, h^1_n, \dots, h^{n-1}_n\}\) 就有一些和单位根一样优秀的性质。

\[h^0_n = 1 \\ h^n_n = g^{p-1} \equiv 1 (定义) \\ h^{\frac{n}{2}}_n = g^{\frac{p-1}{2}} \equiv 1 (因为互异性所以不是 1 ) \\ h^{ak}_{an} = g^{\frac{p-1}{an}ak} = (g^{\frac{p-1}{n}})^k = h^k_n \]

可以发现在 FFT 中使用到的单位根的性质原根全部满足。

所以我们就可以使用原根来代替单位根了。

质数我们一般选择 \(998244353\) 因为 \(998244352\) 有因子 \(2^{22}\)

原根我们一般使用 \(3\),好记。

NTT 代码

#include <bits/stdc++.h>
using namespace std;

namespace Math {
	typedef long long ll;
	const ll MOD = 998244353;
	const ll G = 3;
	const ll GINV = 332748118;


	vector<int> rev;
	void init_NTT(int len, int cnt) {
		rev.resize(len);
		for (int i = 0; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
	} 

	ll qpow(ll a, ll b) {
		ll ret = 1;
		for (; b; b >>= 1) {
			if (b & 1) ret = ret * a % MOD;
			a = a * a % MOD;
		}
		return ret;
	}

	void NTT(int len, vector<int> &a, int op) {
		for (int i = 0; i < len; i ++) if (i < rev[i]) swap(a[i], a[rev[i]]);
		for (int t = 1; t < len; t <<= 1) {
			ll h = qpow(op ? G : GINV,  (MOD - 1) / (t * 2));
			for (int i = 0; i < len; i += 2 * t) {
				ll w = 1;
				for (int j = i; j < i + t; j ++, w = w * h % MOD) {
					ll x = a[j], y = w * a[j + t] % MOD;
					a[j] = (x + y) % MOD, a[j + t] = (x - y + MOD) % MOD;
				}
			}
		}
	}
	
	vector<int> polyMul(vector<int> a, vector<int> b) {
		int n = a.size(), m = b.size();
		int len = 1, cnt = 0;
		while (len < n + m - 1) len *= 2, ++ cnt;
		init_NTT(len, cnt);
		a.resize(len), b.resize(len);
		NTT(len, a, 1), NTT(len, b, 1);
		vector<int> c(len);
		for (int i = 0; i < len; i ++) c[i] = 1ll * a[i] * b[i] % MOD;
		NTT(len, c, 0);
		ll inv = qpow(len, MOD - 2);
		c.resize(n + m - 1);
		for (int i = 0; i < n + m - 1; i ++) c[i] = inv * c[i] % MOD;
		return c;
	}
}

int main() {
	int n, m; scanf("%d%d", &n, &m);
	vector<int> a(n + 1), b(m + 1);
	for (int i = 0; i <= n; i ++) scanf("%d", &a[i]);
	for (int i = 0; i <= m; i ++) scanf("%d", &b[i]);
	vector<int> c = Math::polyMul(a, b);
	for (int i = 0; i < c.size(); i ++) printf("%d ", c[i]);
	return 0;
}

image.png

确实快了不止一点。

完结撒花!

CREDIT

快速傅里叶变换(FFT)详解 - 自为风月马前卒 - 博客园 (cnblogs.com)

十分简明易懂的FFT(快速傅里叶变换)_路人黑的纸巾的博客-CSDN博客_fft