有关卷积的一些多项式变换
前置知识:单位根,原根,*CRT。
单位根
概念
在复数下,满足 \(x^n = 1\) 的 \(x\) 称为 \(n\) 次单位根。
\(n\) 次单位根一共有 \(n\) 个。
将所有的单位根按照辐角大小排列,第 \(k\) 个(\(0 \leq k < n\))个 \(n\) 次单位根为:
\(x_k = e^{i \frac{2 k \pi}{n}}\)
所有的单位根模都是 \(1\),\(n\) 个单位根平分单位圆。
本原单位根:\(0\) 到 \(n - 1\) 次方的值能生成所有 \(n\) 次单位根的单位根称为 \(n\) 次本原单位根。
欧拉公式:\(e^{i \pi} = 1\)
\(x_1 = e^{i \frac{2 \pi}{n}}\) 是一个本原单位根。
记 \(n\) 次本原单位根为 \(\omega_n = e^{i \frac{2 \pi}{n}} = \cos \frac{2 \pi}{n} + i \sin \frac{2 \pi}{n}\)
性质
设 \(n\) 是一个正偶数,则:
\((\omega_n^k)^2 = w_{\frac{n}{2}}^k\)
\(\omega_{n}^{\frac{n}{2} + k} = -\omega_n^k\)
FFT
求出一个 \(n\) 次多项式在每个 \(n\) 次单位根下的点值称为 离散傅里叶变换(DFT),而重新从这 \(n\) 个点值得到 \(n\) 次多项式的过程称为 离散傅里叶逆变换(IDFT)
对于要进行 FFT 的多项式,先将它的次数补到 \(2\) 的整次幂,记其次数为 \(n - 1\),令 \(m = \frac{n}{2}\)
DFT
考虑求一个长度为 \(n\) 的数列 \(b\),其中:
\(\sum\limits_{i = 0}^{n - 1} b_i = \sum\limits_{i = 0}^{n - 1} a_i \times \omega_n^{ik}\)
即这个数列的第 \(k\) 项是原多项式在 \(n\) 次单位根的 \(k\) 次幂处的取值。
FFT
FFT 是对 \(O(n^2)\) 的朴素 DFT 的优化。
考虑把上面的和式按照下标分类,即:
\(A(x) = \sum\limits_{i = 0}^{n - 1} a_i \cdot x_i = \sum\limits_{i = 0}^{m - 1} a_{2i} \cdot x^{2i} + \sum\limits_{i = 0}^{m - 1} a_{2i + 1} \cdot x^{2i + 1}\)
后半部分提出一个 \(x\),得:
\(A(x) = \sum\limits_{i = 0}^{n - 1} a_i \cdot x_i = \sum\limits_{i = 0}^{m - 1} a_{2i} \cdot x^{2i} + x \sum\limits_{i = 0}^{m - 1} a_{2i + 1} \cdot x^{2i} = \sum\limits_{i = 0}^{m - 1} a_{2i} \cdot (x^2)^i + x \sum\limits_{i = 0}^{m - 1} a_{2i + 1} \cdot (x^2)^i\)
设 \(A_0(x), A_1(x)\) 是两个 \(m - 1\) 次多项式,使得 \(A_0(x) = \sum\limits_{i = 0}^{m - 1} a_{2i} x^i, A_1(x) = \sum\limits_{i = 0}^{m - 1} a_{2i + 1} x^i\)
那么 \(A(x) = A_0(x^2) + x A_1(x^2)\)
这说明只要求出 \(A_0, A_1\) 在各处的点值,我们就可以在 \(O(n)\) 时间内计算 \(A\) 在各处的点值!并且子问题的结构是类似地,完全可以递归处理!
但是这样递归下去的复杂度是 \(O(n^2)\)
所以我们可以考虑一下单位根的性质:
\((\omega_n^k)^2 = \omega_m^k\)
\(\omega_{n}^{m + k} = -\omega_n^k\)
考虑 \(m\) 次以内的点值和高于 \(m\) 次的点值之间的关系:
\(\forall 0 \leq k < m, A(\omega_n^k) = A_0((\omega_n^k)^2) + \omega_n^k A_1((\omega_n^k)^2)\)
根据第一个式子化简得到:
\(A(\omega_n^k) = A_0(\omega_m^k) + \omega_n^k A_1(\omega_m^k)\)
对于大于等于 \(m\) 次的点值:
\(A(\omega_n^{m + k}) = A_0((\omega_n^{m + k})^2) + \omega_n^{m + k} A_1((\omega_n^{m + k})^2)\)
根据第二个式子化简得到:
\(A(\omega_n^{m + k}) = A_0((\omega_n^k)^2) - \omega_n^k A_1((\omega_n^k)^2)\)
再根据第一个式子得到:
\(A(\omega_n^{m + k}) = A_0(\omega_m^k) - \omega_n^k A_1(\omega_m^k)\)
比对一下:\(A(\omega_n^k) = A_0(\omega_m^k) + \omega_n^k A_1(\omega_m^k)\)
差别只有后半部分的系数!!!1
所以我们只需要把小于 \(m\) 次的部分递归处理出来,然后再处理后半部分即可。
时间复杂度优化到 \(O(n \log n)\)
上面的推导称为 蝴蝶操作。
IDFT
不太会证明,暂且先放个结论,等以后再回来填坑。
假设当前已知一个 \(n - 1\) 次多项式 \(A(x) = \sum\limits_{i = 0}^{n - 1} a_i \cdot x^i\) 经过 DFT 后得到的点值序列 \(b\),则:
\(b_k = \sum\limits_{i = 0}^{n - 1} a_i \cdot \omega_n^{ik}\)
结论是:\(a_k = \frac{1}{n} \sum\limits_{i = 0}^{n - 1} b_i \omega_n^{-ki}\)
这里有一个消去引理的知识,之后再补回来。
IFFT
如何快速求 \(B\) 在 \(\omega_n^{-ki}\) 处的取值?
根据几何意义易知:\(\omega_n^{-k} = \omega_n^{n - k}\)
所以只需要正常求 \(B\) 在 \(\omega_n^k\) 处的取值,然后对数组取反即可,magic!
优化
递归改成迭代会快很多。
考虑画一下递归调用的原数组下标:
stage 1: 0 1 2 3 4 5 6 7
stage 2: 0 2 4 6|1 3 5 7
stage 3: 0 4|2 6|1 5|3 7
stage 4: 0|4|2|6|1|5|3|7
人类智慧:stage 4 二进制写法反过来刚好是:
000, 001, 010, 011, 100, 110, 101, 111
也就是 01234567,magic!
所以只需要把原本的 \(a\) 数组按照二进制下标反转后的大小排序,倒数第二层的蝴蝶操作就是合并 \(a\) 中相邻的项。
用类似数位 dp 的方法可以快速求出每个值二进制反转后的值:
rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0));
因为 \(rev(rev(i)) = i\),所以对于一个下标 \(p\),若 \(rev(p) = q\),那么 \(rev(q) = p\)。所以只需要对于 \(rev(i) > i\) 的位置交换一下就能排序力,这个称为位逆序置换。
三次变两次
卡常小寄巧。
考虑到朴素卷两个多项式需要三次或者四次 FFT,但是注意到:
\((A(x) + B(x)i)^2 = A^2(x) - B^2(x) + 2 A(x) B(x) i\).
所以考虑直接将 \(A(x) + B(x)i\) 自己和自己卷起来,然后虚部系数除以 2 就是答案了。
代码
#include <cstdio>
#include <cmath>
#include <complex>
#include <iostream>
#include <algorithm>
using namespace std;
const int sz = 5e6 + 5;
const double PI = acos(-1);
int n, m;
int rev[sz];
complex<double> F[sz], G[sz];
void calc_rev(int k) { for (int i = 0; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }
void FFT(complex<double> *A, int n)
{
for (int i = 1; i < n; i++)
if (rev[i] > i) swap(A[i], A[rev[i]]);
for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
{
complex<double> W(cos(PI / m), sin(PI / m)), w(1.0, 0.0);
for (int l = 0, r = len - 1; r <= n; l += len, r += len)
{
auto w0 = w;
for (int p = l; p < l + m; p++)
{
auto x = A[p] + w0 * A[p + m], y = A[p] - w0 * A[p + m];
A[p] = x, A[p + m] = y;
w0 *= W;
}
}
}
}
void IFFT(complex<double> *A, int n)
{
FFT(A, n);
reverse(A + 1, A + n);
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 0, v; i <= n; i++) scanf("%d", &v), F[i] = v;
for (int i = 0, v; i <= m; i++) scanf("%d", &v), G[i] = v;
int len = n + m, k = 1;
while (k <= len) k <<= 1;
calc_rev(k);
FFT(F, k), FFT(G, k);
for (int i = 0; i < k; i++) F[i] *= G[i];
IFFT(F, k);
for (int i = 0; i <= len; i++) printf("%d ", (int)(F[i].real() / k + 0.5));
return 0;
}
NTT
NTT,快速数论变换,对 FFT 的进一步优化。
通过在模意义下用原根代替单位根,避免了浮点数运算。
原根和阶
首先根据欧拉定理可知,当 \(a, m\) 互质时,\(a^{\varphi(m)} \equiv 1 \pmod m\).
因此 \(a_1, a_2, ...\) 一定存在长度为 \(\varphi(m)\) 的循环节,但它不一定是最小循环节。由此,定义 \(a\) 在模 \(m\) 意义下的阶是同余方程 \(a^x \equiv 1 \pmod m\) 的最小正整数解,记作 \(\operatorname{ord}_m a\).
显然 \(\operatorname{ord}_m a\) 一定是 \(\varphi(m)\) 的因数。特别地,当 \(\operatorname{ord}_m a = \varphi(m)\) 时,称 \(a\) 是模 \(m\) 意义下的一个原根。
显然原根 \(a\) 满足 \(a^1, ..., a^{\varphi(m)}\) 在模 \(m\) 意义下各不相同,符合我们要求代入点值各不相同的要求。
因为原根是整数,所以折半引理一类的东西对于原根也成立。
那么就可以考虑用原根代替单位根代入点值了。
有仙人证明过当 \(m\) 是质数时一定存在原根 \(g\),所以令 \(m\) 取质数。
因为原根 \(g\) 是模 \(m\) 意义下的 \(m - 1\) 次单位根,所以当模 \(m\) 意义下的 \(n\) 次单位根 \(\omega_n = g^{\frac{m - 1}{n}}\) 存在时,有 \(n \mid m - 1\).
因为分治时会把多项式的次数补到 \(2\) 的整次幂,所以考虑令 \(m - 1\) 含有较多的质因子 \(2\),从而得到 NTT 的常用模数 \(998244353\).
\(998244353, 167772161, 104857601\) 都是合法模数,并且原根都为 \(3\).
代码
void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }
void NTT(ll *A, int n)
{
calc_rev(n);
for (int i = 1; i < n; i++)
if (rev[i] > i) swap(A[i], A[rev[i]]);
for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
{
ll wn = qpow(g, (mod - 1) / len);
wp[0] = 1;
for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
for (int l = 0, r = len - 1; r <= n; l += len, r += len)
{
int w = 0;
for (int p = l; p < l + m; p++, w++)
{
ll x = A[p], y = wp[w] * A[p + m] % mod;
A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
}
}
}
}
void INTT(ll *A, int n)
{
NTT(A, n);
reverse(A + 1, A + n);
int inv = qpow(n, mod - 2);
for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}
拆系数 FFT
这个做法适用于模数为任何数的情况,但是需要 4 次 FFT.
理论上常数是比 CRT 合并做法要小,但是 Prean 似乎搞出了跑的很快的 5 次 FFT 做法,可以看他的题解。
思路是考虑把多项式 \(A(x), B(x)\) 拆成四个多项式 \(A_0(x), A_1(x), B_0(x), B_1(x)\).
咕咕咕,有时间再更。
#include <cstdio>
#include <cmath>
#include <complex>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const int sz = 4e5 + 5;
// const double PI = acos(-1);
const complex<double> I(0, 1);
int n, m, mod;
int rev[sz];
complex<double> a0[sz], b0[sz], a1[sz], b1[sz];
complex<double> p[sz], q[sz], wn[sz];
void calc_rev(int k) { for (int i = 0; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }
void FFT(complex<double> *A, int n)
{
calc_rev(n);
for (int i = 1; i < n; i++)
if (rev[i] > i) swap(A[i], A[rev[i]]);
for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
{
for (int l = 0, r = len - 1; r <= n; l += len, r += len)
{
for (int p = l; p < l + m; p++)
{
complex<double> w = wn[1ll * (p - l) * n / m];
complex<double> a0 = A[p], a1 = A[p + m] * w;
A[p] = a0 + a1, A[p + m] = a0 - a1;
}
}
}
}
void IFFT(complex<double> *A, int n)
{
FFT(A, n);
reverse(A + 1, A + n);
for (int i = 0; i < n; i++) A[i] /= n;
}
void fft(complex<double> *A, complex<double> *B, int len)
{
for (int i = 0; i < len; i++) A[i] = A[i] + I * B[i];
FFT(A, len);
for (int i = 0; i < len; i++) B[i] = conj(A[i ? len - i : 0]);
for (int i = 0; i < len; i++)
{
complex<double> p = A[i], q = B[i];
A[i] = (p + q) * 0.5, B[i] = (q - p) * 0.5 * I;
}
}
ll num(complex<double> x)
{
double d = x.real();
return (d < 0 ? ll(d - 0.5) % mod : ll(d + 0.5) % mod);
}
int main()
{
scanf("%d%d%d", &n, &m, &mod);
int sqn = int(sqrt(mod) + 1);
for (int i = 0, w; i <= n; i++)
{
scanf("%d", &w), w %= mod;
a0[i] = w / sqn, a1[i] = w % sqn;
}
for (int i = 0, w; i <= m; i++)
{
scanf("%d", &w), w %= mod;
b0[i] = w / sqn, b1[i] = w % sqn;
}
int len = 1;
while (len < (n + m + 1)) len <<= 1;
for (int i = 0; i < len; i++) wn[i] = complex<double>(cos(M_PI / len * i), sin(M_PI / len * i));
fft(a0, a1, len), fft(b0, b1, len);
// for (int i = 0; i < len; i++) printf("%lf %lf %lf %lf\n", a0[i].real(), a0[i].imag(), a1[i].real(), a1[i].imag());
// for (int i = 0; i < len; i++) printf("%lf %lf %lf %lf\n", b0[i].real(), b0[i].imag(), b1[i].real(), b1[i].imag());
for (int i = 0; i < len; i++)
{
p[i] = a0[i] * b0[i] + I * a1[i] * b0[i];
q[i] = a0[i] * b1[i] + I * a1[i] * b1[i];
}
IFFT(p, len), IFFT(q, len);
for (int i = 0; i <= n + m; i++) printf("%lld ", (sqn * sqn * num(p[i].real()) % mod + sqn * (num(p[i].imag()) + num(q[i].real())) % mod + num(q[i].imag())) % mod);
return 0;
}
分治 FFT
其实是分治 NTT.
对于给定的 \(n\) 次多项式 \(g\),求多项式 \(f\) 使得 \(f[i] = \sum\limits_{j = 0}^i f[i - j] g[j]\).
其实也就是求一个无标号计数(背包)。
考虑对于 \(f\) 的指数进行类似 cdq 的分治。换言之,对于指数在 \([l, r]\) 之间的项,考虑 \([l, \lfloor \frac{l + r}{2} \rfloor]\) 之间的项对于 \([\lfloor \frac{l + r}{2} \rfloor + 1, r]\) 之间的项的贡献。
类似背包的思路,每次将 \([l, \lfloor \frac{l + r}{2} \rfloor]\) 之间的项与 \(G\) 的第 \([0, r - l]\) 项卷起来,得到的多项式的第 \(k\) 项即为左半区间对第 \(l + k\) 项的贡献。
按照中序分治整个序列,每层的复杂度为卷积 \(O(n \log n)\),共分治 \(O(\log n)\) 层,总复杂度为 \(O(n \log^2 n)\).
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long ll;
const int lg_sz = 18;
const int maxn = (1 << lg_sz) + 1;
const int mod = 998244353;
const int g = 3;
int n;
int rev[maxn];
ll wp[maxn], F[maxn], G[maxn], Ft[maxn], Gt[maxn];
void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }
ll qpow(ll base, ll power, ll mod)
{
ll res = 1;
while (power)
{
if (power & 1) res = res * base % mod;
base = base * base % mod;
power >>= 1;
}
return res;
}
void NTT(ll *A, int n)
{
calc_rev(n);
for (int i = 1; i < n; i++)
if (rev[i] > i) swap(A[i], A[rev[i]]);
for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
{
ll wn = qpow(g, (mod - 1) / len, mod);
wp[0] = 1;
for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
for (int l = 0, r = len - 1; r <= n; l += len, r += len)
{
int w = 0;
for (int p = l; p < l + m; p++, w++)
{
ll x = A[p], y = wp[w] * A[p + m] % mod;
A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
}
}
}
}
void INTT(ll *A, int n)
{
NTT(A, n);
reverse(A + 1, A + n);
int inv = qpow(n, mod - 2, mod);
for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}
void solve(int l, int r, int lg)
{
if ((!lg) || (l >= n)) return;
int mid = (l + r) >> 1;
solve(l, mid, lg - 1);
memset(Ft + (r - l) / 2, 0, (r - l) * sizeof(ll));
memcpy(Ft, F + l, (r - l) / 2 * sizeof(ll));
memcpy(Gt, G, (r - l) * sizeof(ll));
NTT(Ft, (1 << lg)), NTT(Gt, (1 << lg));
for (int i = 0; i < r - l; i++) Ft[i] = Ft[i] * Gt[i] % mod;
INTT(Ft, (1 << lg));
for (int i = (r - l) / 2; i < r - l; i++) F[l + i] = (F[l + i] + Ft[i]) % mod;
solve(mid, r, lg - 1);
}
int main()
{
scanf("%d", &n);
int lg = 0;
while ((1 << lg) < n) lg++;
for (int i = 1; i < n; i++) scanf("%lld", &G[i]);
F[0] = 1;
solve(0, (1 << lg), lg);
for (int i = 0; i < n; i++) printf("%lld ", F[i]); putchar('\n');
return 0;
}