多项式乘法

给定两个多项式

A(x)=i=0n1aixi

B(x)=i=0m1bixi

C(x)=A(x)×B(x)=i=0n+m1cixi

前置知识

单位根

n 次单位根即为满足 xn=1 的 x。

由代数基本定理可得 n 次单位根应该有 n 个。

k 个记为 ωnk

单位根的一些性质

ωnk=cos2kπn+isin2kπn(k=0,1,,n1)ω2n2k=ωnkωnk+n2=ωnkωniωnj=ωni+j(ωnk)j=ωnjk

从大佬博客偷来的图

原根的性质(p 为质数)

i,jZ,gigj(modp)gp11(modp)

多项式的点表示法

众所周知,一个 n1 次多项式可以用 n 个互异的点表示出来。

这就叫做多项式的点表示法。

DFT

上面的点表示法启示我们了一种求多项式乘法的方法。

一个 n 次的多项式 A(x),对于取值 {x1,x2,x3,,xn},可以求出{yA1,yA2,yA3,,yAn}

同理一个 n 次的多项式 B(x),也可以求出相对应的 {yB1,yB2,yB3,,yBn}

然后我们对于每一个 xi 可以求出一个对应的 yCi=yAi×yBi

得到了 {yC1,yC2,yC3,,yCn},我们就可以用拉格朗日插值法求出多项式C(x)=A(x)×B(x) 了。

当我们将取值 {x1,x2,x3,,xn} 取为 {ωn0,ωn1,ωn2,,ωnn1} 时我们就可以少算很多的次幂。

利用上述方法求多项式乘法的方法就叫做 DFT

FFT

上面的 DFT 虽然看起来很妙,但仍然是 O(n2) 的,其主要复杂度为求出对应的 {yA1,yA2,yA3,,yAn}{yB1,yB2,yB3,,yBn}

我们考虑优化这个过程。

接下来的过程中我们假设 n2 的整数次幂。

A(x)=a0+a1x+a2x2+a3x3++an1xn1

按照 x 次数的奇偶分类。

A(x)=(a0+a2x2++an2xn2)+x(a1+a3x2++an1xn2)

A1(x)=(a0+a2x++an2xn22)

A2(x)=(a1+a3x++an1xn22)

A(x)=A1(x2)+xA2(x2)

显然 A1(x)A2(x) 都为 n2 次多项式。

接下来就是 FFT 的核心。

对于所有 kn2,有

A(ωnk)=A1(ωn2k)+ωnkA2(ωn2k)A(ωnk)=A1(ωn2k)+ωn2kA2(ωn2k)

A(ωnk+n2)=A1(ωn2k+n)+ωnk+n2A2(ωn2k+n)A(ωnk+n2)=A1(ωn2k)ωnkA2(ωn2k)A(ωnk+n2)=A1(ωn2k)ωn2kA2(ωn2k)

然后我们就发现了一个很神奇的东西,如果我们知道了 A1(ωn2k)A2(ωn2k) 我们就可以同时求出 A(ωnk)A(ωnk+n2)

并且由于 A1(x)A2(x) 都为 n2 次多项式,我们可以用相同的办法求出他们,并且每次将次数缩小到 12

当递归到 1 次时直接带入求值即可。

这就是一个类似于分治的复杂度了,不难证明这个过程是 O(nlogn) 的。

不过我们仍然还有问题需要解决,上面我们只优化了求出对应的点坐标的复杂度,并没有优化求出系数的复杂度,如果我们用拉格朗日插值来求出系数这个算法的复杂度还是 O(n2) 的。

但这就是 DFT 的优越之处,它的特殊取值使得它可以同样在 O(nlogn) 的复杂度内求出对应的系数。

IFFT

IFFT 即将点值变为系数的过程。

有个比较神奇的结论,记我们通过 FFT 算出来的结果的乘积为 {yC0,yC1,yC2,,yCn1}

记多项式 D(x)=yC0+yC1x+yC2x2++yCn1xn1{ωn0,ωn1,ωn2,,ωn(n1)} 处的取值为 {d0,d1,d2,,dn1}

ci=din

C(x)=d0n+d1nx+d2nx2++dn1nxn1

IFFT 证明

dk=i=0n1yCi(ωnk)idk=i=0n1(j=0n1cj(ωni)j)(ωnk)idk=i=0n1(j=0n1cj(ωnj)i)(ωnk)idk=i=0n1j=0n1cj(ωnjk)idk=j=0n1cji=0n1(ωnjk)i

S(k)=i=0n1(ωnk)i

根据等比数列求和公式,首项为 1,公比为 wnk ,当公比不为 0,即 k0 时,有:

S(k)=1(ωnk)n1ωnk=1(ωnn)k1ωnk=11k1ωnk=0

k=0 时,有:

S(k)=i=0n11i=n

所以不难得出

dk=j=0n1S(jk)cjdk=nckck=dkn

FFT 未优化代码

#include <bits/stdc++.h>
using namespace std;
namespace Math {
#define PI 3.14159265358979323846
struct Complex { double r, i; Complex(double r = 0, double i = 0): r(r), i(i) { } };
Complex operator + (const Complex &a, const Complex &b) { return { a.r + b.r, a.i + b.i }; }
Complex operator - (const Complex &a, const Complex &b) { return { a.r - b.r, a.i - b.i }; }
Complex operator * (const Complex &a, const Complex &b) { return { a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r }; }
void FFT(int len, vector<Complex> &a, int op) {
if (len == 1) return;
vector<Complex> a1(len / 2), a2(len / 2);
for (int i = 0; i < a.size(); i += 2) a1[i / 2] = a[i], a2[i / 2] = a[i + 1];
FFT(len / 2, a1, op), FFT(len / 2, a2, op);
Complex step = { cos(PI * 2 / len), op * sin(PI * 2 / len) }, omega = { 1, 0 };
for (int i = 0; i < len / 2; i ++, omega = omega * step) a[i] = a1[i] + omega * a2[i], a[i + len / 2] = a1[i] - omega * a2[i];
}
vector<int> polyMul(vector<int> a, vector<int> b) {
int n = a.size(), m = b.size();
int len = 1;
while (len < n + m - 1) len *= 2;
vector<Complex> ac(len), bc(len);
Complex init = { 0, 0 };
for (int i = 0; i < len; i ++) ac[i] = (i < n ? a[i] : init);
for (int i = 0; i < len; i ++) bc[i] = (i < m ? b[i] : init);
FFT(len, ac, 1), FFT(len, bc, 1); // FFT
vector<Complex> cc(len);
for (int i = 0; i < len; i ++) cc[i] = ac[i] * bc[i];
FFT(len, cc, -1); // IFFT
vector<int> c(n + m - 1);
for (int i = 0; i < n + m - 1; i ++) c[i] = cc[i].r / len + 0.5;
return c;
}
#undef PI
}
int main() {
int n, m; scanf("%d%d", &n, &m);
vector<int> a(n + 1), b(m + 1);
for (int i = 0; i <= n; i ++) scanf("%d", &a[i]);
for (int i = 0; i <= m; i ++) scanf("%d", &b[i]);
vector<int> c = Math::polyMul(a, b);
for (int i = 0; i < c.size(); i ++) printf("%d ", c[i]);
return 0;
}

虽然看网上博客说分治版本的 FFT 没法通过洛谷的模板题,但我莫名其妙过了……

但是常数确实是比较大的。

image.png

FFT 优化

主要是可以优化掉递归的常数。

不难发现可以直接通过下表二进制反转后分组来规避掉递归的常数。(为什么?因为奇偶分组是用最低位来决定的,所以考虑这个序列最终的顺序就应该倒过来)

代码:

#include <bits/stdc++.h>
using namespace std;
namespace Math {
#define PI 3.14159265358979323846
struct Complex { double r, i; Complex(double r = 0, double i = 0): r(r), i(i) { } };
Complex operator + (const Complex &a, const Complex &b) { return { a.r + b.r, a.i + b.i }; }
Complex operator - (const Complex &a, const Complex &b) { return { a.r - b.r, a.i - b.i }; }
Complex operator * (const Complex &a, const Complex &b) { return { a.r * b.r - a.i * b.i, a.r * b.i + a.i * b.r }; }
vector<int> rev;
void init_FFT(int len, int cnt) {
rev.resize(len);
for (int i = 0; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
}
// fast implement
void FFT(int len, vector<Complex> &a, int op) {
for (int i = 0; i < len; i ++) if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int t = 1; t < len; t <<= 1) {
Complex step = { cos(PI / t), op * sin(PI / t) };
for (int i = 0; i < len; i += 2 * t) {
Complex omega = { 1, 0 };
for (int j = i; j < i + t; j ++, omega = omega * step) {
Complex x = a[j], y = omega * a[j + t];
a[j] = x + y, a[j + t] = x - y;
}
}
}
}
vector<int> polyMul(vector<int> a, vector<int> b) {
int n = a.size(), m = b.size();
int len = 1, cnt = 0;
while (len < n + m - 1) len *= 2, ++ cnt;
init_FFT(len, cnt);
vector<Complex> ac(len), bc(len);
Complex init = { 0, 0 };
for (int i = 0; i < len; i ++) ac[i] = (i < n ? a[i] : init);
for (int i = 0; i < len; i ++) bc[i] = (i < m ? b[i] : init);
FFT(len, ac, 1), FFT(len, bc, 1); // FFT
vector<Complex> cc(len);
for (int i = 0; i < len; i ++) cc[i] = ac[i] * bc[i];
FFT(len, cc, -1); // IFFT
vector<int> c(n + m - 1);
for (int i = 0; i < n + m - 1; i ++) c[i] = cc[i].r / len + 0.5;
return c;
}
#undef PI
}
int main() {
int n, m; scanf("%d%d", &n, &m);
vector<int> a(n + 1), b(m + 1);
for (int i = 0; i <= n; i ++) scanf("%d", &a[i]);
for (int i = 0; i <= m; i ++) scanf("%d", &b[i]);
vector<int> c = Math::polyMul(a, b);
for (int i = 0; i < c.size(); i ++) printf("%d ", c[i]);
return 0;
}

优化掉了一秒,可喜可贺。

image.png

NTT

FFT 中由于要使用单位根,精度是一个很大的问题,所以就有了 NTT。

在下文中,n2 的整数次幂,p 为质数且 np1gp 的一个原根。

我们设 hnk=(gp1n)k

那么 {hn0,hn1,,hnn1} 就有一些和单位根一样优秀的性质。

hn0=1hnn=gp11()hnn2=gp121(1)hanak=gp1anak=(gp1n)k=hnk

可以发现在 FFT 中使用到的单位根的性质原根全部满足。

所以我们就可以使用原根来代替单位根了。

质数我们一般选择 998244353 因为 998244352 有因子 222

原根我们一般使用 3,好记。

NTT 代码

#include <bits/stdc++.h>
using namespace std;
namespace Math {
typedef long long ll;
const ll MOD = 998244353;
const ll G = 3;
const ll GINV = 332748118;
vector<int> rev;
void init_NTT(int len, int cnt) {
rev.resize(len);
for (int i = 0; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (cnt - 1));
}
ll qpow(ll a, ll b) {
ll ret = 1;
for (; b; b >>= 1) {
if (b & 1) ret = ret * a % MOD;
a = a * a % MOD;
}
return ret;
}
void NTT(int len, vector<int> &a, int op) {
for (int i = 0; i < len; i ++) if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int t = 1; t < len; t <<= 1) {
ll h = qpow(op ? G : GINV, (MOD - 1) / (t * 2));
for (int i = 0; i < len; i += 2 * t) {
ll w = 1;
for (int j = i; j < i + t; j ++, w = w * h % MOD) {
ll x = a[j], y = w * a[j + t] % MOD;
a[j] = (x + y) % MOD, a[j + t] = (x - y + MOD) % MOD;
}
}
}
}
vector<int> polyMul(vector<int> a, vector<int> b) {
int n = a.size(), m = b.size();
int len = 1, cnt = 0;
while (len < n + m - 1) len *= 2, ++ cnt;
init_NTT(len, cnt);
a.resize(len), b.resize(len);
NTT(len, a, 1), NTT(len, b, 1);
vector<int> c(len);
for (int i = 0; i < len; i ++) c[i] = 1ll * a[i] * b[i] % MOD;
NTT(len, c, 0);
ll inv = qpow(len, MOD - 2);
c.resize(n + m - 1);
for (int i = 0; i < n + m - 1; i ++) c[i] = inv * c[i] % MOD;
return c;
}
}
int main() {
int n, m; scanf("%d%d", &n, &m);
vector<int> a(n + 1), b(m + 1);
for (int i = 0; i <= n; i ++) scanf("%d", &a[i]);
for (int i = 0; i <= m; i ++) scanf("%d", &b[i]);
vector<int> c = Math::polyMul(a, b);
for (int i = 0; i < c.size(); i ++) printf("%d ", c[i]);
return 0;
}

image.png

确实快了不止一点。

完结撒花!

CREDIT

快速傅里叶变换(FFT)详解 - 自为风月马前卒 - 博客园 (cnblogs.com)

十分简明易懂的FFT(快速傅里叶变换)_路人黑的纸巾的博客-CSDN博客_fft