Loading

有关卷积的一些多项式变换

前置知识:单位根,原根,*CRT。

单位根

概念

在复数下,满足 \(x^n = 1\)\(x\) 称为 \(n\) 次单位根。

\(n\) 次单位根一共有 \(n\) 个。

将所有的单位根按照辐角大小排列,第 \(k\) 个(\(0 \leq k < n\))个 \(n\) 次单位根为:

\(x_k = e^{i \frac{2 k \pi}{n}}\)

所有的单位根模都是 \(1\)\(n\) 个单位根平分单位圆。

本原单位根:\(0\)\(n - 1\) 次方的值能生成所有 \(n\) 次单位根的单位根称为 \(n\) 次本原单位根。

欧拉公式:\(e^{i \pi} = 1\)

\(x_1 = e^{i \frac{2 \pi}{n}}\) 是一个本原单位根。

\(n\) 次本原单位根为 \(\omega_n = e^{i \frac{2 \pi}{n}} = \cos \frac{2 \pi}{n} + i \sin \frac{2 \pi}{n}\)

性质

\(n\) 是一个正偶数,则:

\((\omega_n^k)^2 = w_{\frac{n}{2}}^k\)

\(\omega_{n}^{\frac{n}{2} + k} = -\omega_n^k\)

FFT

求出一个 \(n\) 次多项式在每个 \(n\) 次单位根下的点值称为 离散傅里叶变换(DFT),而重新从这 \(n\) 个点值得到 \(n\) 次多项式的过程称为 离散傅里叶逆变换(IDFT)

对于要进行 FFT 的多项式,先将它的次数补到 \(2\) 的整次幂,记其次数为 \(n - 1\),令 \(m = \frac{n}{2}\)

DFT

考虑求一个长度为 \(n\) 的数列 \(b\),其中:

\(\sum\limits_{i = 0}^{n - 1} b_i = \sum\limits_{i = 0}^{n - 1} a_i \times \omega_n^{ik}\)

即这个数列的第 \(k\) 项是原多项式在 \(n\) 次单位根的 \(k\) 次幂处的取值。

FFT

FFT 是对 \(O(n^2)\) 的朴素 DFT 的优化。

考虑把上面的和式按照下标分类,即:

\(A(x) = \sum\limits_{i = 0}^{n - 1} a_i \cdot x_i = \sum\limits_{i = 0}^{m - 1} a_{2i} \cdot x^{2i} + \sum\limits_{i = 0}^{m - 1} a_{2i + 1} \cdot x^{2i + 1}\)

后半部分提出一个 \(x\),得:

\(A(x) = \sum\limits_{i = 0}^{n - 1} a_i \cdot x_i = \sum\limits_{i = 0}^{m - 1} a_{2i} \cdot x^{2i} + x \sum\limits_{i = 0}^{m - 1} a_{2i + 1} \cdot x^{2i} = \sum\limits_{i = 0}^{m - 1} a_{2i} \cdot (x^2)^i + x \sum\limits_{i = 0}^{m - 1} a_{2i + 1} \cdot (x^2)^i\)

\(A_0(x), A_1(x)\) 是两个 \(m - 1\) 次多项式,使得 \(A_0(x) = \sum\limits_{i = 0}^{m - 1} a_{2i} x^i, A_1(x) = \sum\limits_{i = 0}^{m - 1} a_{2i + 1} x^i\)

那么 \(A(x) = A_0(x^2) + x A_1(x^2)\)

这说明只要求出 \(A_0, A_1\) 在各处的点值,我们就可以在 \(O(n)\) 时间内计算 \(A\) 在各处的点值!并且子问题的结构是类似地,完全可以递归处理!

但是这样递归下去的复杂度是 \(O(n^2)\)

所以我们可以考虑一下单位根的性质:

\((\omega_n^k)^2 = \omega_m^k\)

\(\omega_{n}^{m + k} = -\omega_n^k\)

考虑 \(m\) 次以内的点值和高于 \(m\) 次的点值之间的关系:

\(\forall 0 \leq k < m, A(\omega_n^k) = A_0((\omega_n^k)^2) + \omega_n^k A_1((\omega_n^k)^2)\)

根据第一个式子化简得到:

\(A(\omega_n^k) = A_0(\omega_m^k) + \omega_n^k A_1(\omega_m^k)\)

对于大于等于 \(m\) 次的点值:

\(A(\omega_n^{m + k}) = A_0((\omega_n^{m + k})^2) + \omega_n^{m + k} A_1((\omega_n^{m + k})^2)\)

根据第二个式子化简得到:

\(A(\omega_n^{m + k}) = A_0((\omega_n^k)^2) - \omega_n^k A_1((\omega_n^k)^2)\)

再根据第一个式子得到:

\(A(\omega_n^{m + k}) = A_0(\omega_m^k) - \omega_n^k A_1(\omega_m^k)\)

比对一下:\(A(\omega_n^k) = A_0(\omega_m^k) + \omega_n^k A_1(\omega_m^k)\)

差别只有后半部分的系数!!!1

所以我们只需要把小于 \(m\) 次的部分递归处理出来,然后再处理后半部分即可。

时间复杂度优化到 \(O(n \log n)\)

上面的推导称为 蝴蝶操作

IDFT

不太会证明,暂且先放个结论,等以后再回来填坑。

假设当前已知一个 \(n - 1\) 次多项式 \(A(x) = \sum\limits_{i = 0}^{n - 1} a_i \cdot x^i\) 经过 DFT 后得到的点值序列 \(b\),则:

\(b_k = \sum\limits_{i = 0}^{n - 1} a_i \cdot \omega_n^{ik}\)

结论是:\(a_k = \frac{1}{n} \sum\limits_{i = 0}^{n - 1} b_i \omega_n^{-ki}\)

这里有一个消去引理的知识,之后再补回来。

IFFT

如何快速求 \(B\)\(\omega_n^{-ki}\) 处的取值?

根据几何意义易知:\(\omega_n^{-k} = \omega_n^{n - k}\)

所以只需要正常求 \(B\)\(\omega_n^k\) 处的取值,然后对数组取反即可,magic!

优化

递归改成迭代会快很多。

考虑画一下递归调用的原数组下标:

stage 1: 0 1 2 3 4 5 6 7
stage 2: 0 2 4 6|1 3 5 7
stage 3: 0 4|2 6|1 5|3 7
stage 4: 0|4|2|6|1|5|3|7

人类智慧:stage 4 二进制写法反过来刚好是:

000, 001, 010, 011, 100, 110, 101, 111

也就是 01234567,magic!

所以只需要把原本的 \(a\) 数组按照二进制下标反转后的大小排序,倒数第二层的蝴蝶操作就是合并 \(a\) 中相邻的项。

用类似数位 dp 的方法可以快速求出每个值二进制反转后的值:

rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0));

因为 \(rev(rev(i)) = i\),所以对于一个下标 \(p\),若 \(rev(p) = q\),那么 \(rev(q) = p\)。所以只需要对于 \(rev(i) > i\) 的位置交换一下就能排序力,这个称为位逆序置换。

三次变两次

卡常小寄巧。

考虑到朴素卷两个多项式需要三次或者四次 FFT,但是注意到:

\((A(x) + B(x)i)^2 = A^2(x) - B^2(x) + 2 A(x) B(x) i\).

所以考虑直接将 \(A(x) + B(x)i\) 自己和自己卷起来,然后虚部系数除以 2 就是答案了。

代码

#include <cstdio>
#include <cmath>
#include <complex>
#include <iostream>
#include <algorithm>
using namespace std;

const int sz = 5e6 + 5;
const double PI = acos(-1);

int n, m;
int rev[sz];
complex<double> F[sz], G[sz];

void calc_rev(int k) { for (int i = 0; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }

void FFT(complex<double> *A, int n)
{
    for (int i = 1; i < n; i++)
        if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
    {
        complex<double> W(cos(PI / m), sin(PI / m)), w(1.0, 0.0);
        for (int l = 0, r = len - 1; r <= n; l += len, r += len)
        {
            auto w0 = w;
            for (int p = l; p < l + m; p++)
            {
                auto x = A[p] + w0 * A[p + m], y = A[p] - w0 * A[p + m];
                A[p] = x, A[p + m] = y;
                w0 *= W;
            }
        }
    }
}

void IFFT(complex<double> *A, int n)
{
    FFT(A, n);
    reverse(A + 1, A + n);
}

int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 0, v; i <= n; i++) scanf("%d", &v), F[i] = v;
    for (int i = 0, v; i <= m; i++) scanf("%d", &v), G[i] = v;
    int len = n + m, k = 1;
    while (k <= len) k <<= 1;
    calc_rev(k);
    FFT(F, k), FFT(G, k);
    for (int i = 0; i < k; i++) F[i] *= G[i];
    IFFT(F, k);
    for (int i = 0; i <= len; i++) printf("%d ", (int)(F[i].real() / k + 0.5));
    return 0;
}

NTT

NTT,快速数论变换,对 FFT 的进一步优化。

通过在模意义下用原根代替单位根,避免了浮点数运算。

原根和阶

首先根据欧拉定理可知,当 \(a, m\) 互质时,\(a^{\varphi(m)} \equiv 1 \pmod m\).

因此 \(a_1, a_2, ...\) 一定存在长度为 \(\varphi(m)\) 的循环节,但它不一定是最小循环节。由此,定义 \(a\) 在模 \(m\) 意义下的阶是同余方程 \(a^x \equiv 1 \pmod m\) 的最小正整数解,记作 \(\operatorname{ord}_m a\).

显然 \(\operatorname{ord}_m a\) 一定是 \(\varphi(m)\) 的因数。特别地,当 \(\operatorname{ord}_m a = \varphi(m)\) 时,称 \(a\) 是模 \(m\) 意义下的一个原根。

显然原根 \(a\) 满足 \(a^1, ..., a^{\varphi(m)}\) 在模 \(m\) 意义下各不相同,符合我们要求代入点值各不相同的要求。

因为原根是整数,所以折半引理一类的东西对于原根也成立。

那么就可以考虑用原根代替单位根代入点值了。

有仙人证明过当 \(m\) 是质数时一定存在原根 \(g\),所以令 \(m\) 取质数。

因为原根 \(g\) 是模 \(m\) 意义下的 \(m - 1\) 次单位根,所以当模 \(m\) 意义下的 \(n\) 次单位根 \(\omega_n = g^{\frac{m - 1}{n}}\) 存在时,有 \(n \mid m - 1\).

因为分治时会把多项式的次数补到 \(2\) 的整次幂,所以考虑令 \(m - 1\) 含有较多的质因子 \(2\),从而得到 NTT 的常用模数 \(998244353\).

\(998244353, 167772161, 104857601\) 都是合法模数,并且原根都为 \(3\).

代码

void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }

void NTT(ll *A, int n)
{
    calc_rev(n);
    for (int i = 1; i < n; i++)
        if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
    {
        ll wn = qpow(g, (mod - 1) / len);
        wp[0] = 1;
        for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
        for (int l = 0, r = len - 1; r <= n; l += len, r += len)
        {
            int w = 0;
            for (int p = l; p < l + m; p++, w++)
            {
                ll x = A[p], y = wp[w] * A[p + m] % mod;
                A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
            }
        }
    }
}

void INTT(ll *A, int n)
{
    NTT(A, n);
    reverse(A + 1, A + n);
    int inv = qpow(n, mod - 2);
    for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}

拆系数 FFT

这个做法适用于模数为任何数的情况,但是需要 4 次 FFT.

理论上常数是比 CRT 合并做法要小,但是 Prean 似乎搞出了跑的很快的 5 次 FFT 做法,可以看他的题解。

思路是考虑把多项式 \(A(x), B(x)\) 拆成四个多项式 \(A_0(x), A_1(x), B_0(x), B_1(x)\).

咕咕咕,有时间再更。

#include <cstdio>
#include <cmath>
#include <complex>
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long ll;

const int sz = 4e5 + 5;
// const double PI = acos(-1);
const complex<double> I(0, 1);

int n, m, mod;
int rev[sz];
complex<double> a0[sz], b0[sz], a1[sz], b1[sz];
complex<double> p[sz], q[sz], wn[sz];

void calc_rev(int k) { for (int i = 0; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }

void FFT(complex<double> *A, int n)
{
    calc_rev(n);
    for (int i = 1; i < n; i++)
        if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
    {
        for (int l = 0, r = len - 1; r <= n; l += len, r += len)
        {
            for (int p = l; p < l + m; p++)
            {
                complex<double> w = wn[1ll * (p - l) * n / m];
                complex<double> a0 = A[p], a1 = A[p + m] * w;
                A[p] = a0 + a1, A[p + m] = a0 - a1;
            }
        }
    }
}

void IFFT(complex<double> *A, int n)
{
    FFT(A, n);
    reverse(A + 1, A + n);
    for (int i = 0; i < n; i++) A[i] /= n;
}

void fft(complex<double> *A, complex<double> *B, int len)
{
    for (int i = 0; i < len; i++) A[i] = A[i] + I * B[i];
    FFT(A, len);
    for (int i = 0; i < len; i++) B[i] = conj(A[i ? len - i : 0]);
    for (int i = 0; i < len; i++)
    {
        complex<double> p = A[i], q = B[i];
        A[i] = (p + q) * 0.5, B[i] = (q - p) * 0.5 * I;
    }
}

ll num(complex<double> x)
{
    double d = x.real();
    return (d < 0 ? ll(d - 0.5) % mod : ll(d + 0.5) % mod);
}

int main()
{
    scanf("%d%d%d", &n, &m, &mod);
    int sqn = int(sqrt(mod) + 1);
    for (int i = 0, w; i <= n; i++)
    {
        scanf("%d", &w), w %= mod;
        a0[i] = w / sqn, a1[i] = w % sqn;
    }
    for (int i = 0, w; i <= m; i++)
    {
        scanf("%d", &w), w %= mod;
        b0[i] = w / sqn, b1[i] = w % sqn;
    }
    int len = 1;
    while (len < (n + m + 1)) len <<= 1;
    for (int i = 0; i < len; i++) wn[i] = complex<double>(cos(M_PI / len * i), sin(M_PI / len * i));
    fft(a0, a1, len), fft(b0, b1, len);
    // for (int i = 0; i < len; i++) printf("%lf %lf %lf %lf\n", a0[i].real(), a0[i].imag(), a1[i].real(), a1[i].imag());
    // for (int i = 0; i < len; i++) printf("%lf %lf %lf %lf\n", b0[i].real(), b0[i].imag(), b1[i].real(), b1[i].imag());
    for (int i = 0; i < len; i++)
    {
        p[i] = a0[i] * b0[i] + I * a1[i] * b0[i];
        q[i] = a0[i] * b1[i] + I * a1[i] * b1[i];
    }
    IFFT(p, len), IFFT(q, len);
    for (int i = 0; i <= n + m; i++) printf("%lld ", (sqn * sqn * num(p[i].real()) % mod + sqn * (num(p[i].imag()) + num(q[i].real())) % mod + num(q[i].imag())) % mod); 
    return 0;
}

分治 FFT

其实是分治 NTT.

对于给定的 \(n\) 次多项式 \(g\),求多项式 \(f\) 使得 \(f[i] = \sum\limits_{j = 0}^i f[i - j] g[j]\).

其实也就是求一个无标号计数(背包)。

考虑对于 \(f\) 的指数进行类似 cdq 的分治。换言之,对于指数在 \([l, r]\) 之间的项,考虑 \([l, \lfloor \frac{l + r}{2} \rfloor]\) 之间的项对于 \([\lfloor \frac{l + r}{2} \rfloor + 1, r]\) 之间的项的贡献。

类似背包的思路,每次将 \([l, \lfloor \frac{l + r}{2} \rfloor]\) 之间的项与 \(G\) 的第 \([0, r - l]\) 项卷起来,得到的多项式的第 \(k\) 项即为左半区间对第 \(l + k\) 项的贡献。

按照中序分治整个序列,每层的复杂度为卷积 \(O(n \log n)\),共分治 \(O(\log n)\) 层,总复杂度为 \(O(n \log^2 n)\).

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

typedef long long ll;

const int lg_sz = 18;
const int maxn = (1 << lg_sz) + 1;
const int mod = 998244353;
const int g = 3;

int n;
int rev[maxn];
ll wp[maxn], F[maxn], G[maxn], Ft[maxn], Gt[maxn];

void calc_rev(int k) { for (int i = 1; i < k; i++) rev[i] = (rev[i >> 1] >> 1 | (i & 1 ? k >> 1 : 0)); }

ll qpow(ll base, ll power, ll mod)
{
    ll res = 1;
    while (power)
    {
        if (power & 1) res = res * base % mod;
        base = base * base % mod;
        power >>= 1;
    }
    return res;
}

void NTT(ll *A, int n)
{
    calc_rev(n);
    for (int i = 1; i < n; i++)
        if (rev[i] > i) swap(A[i], A[rev[i]]);
    for (int len = 2, m = 1; len <= n; m = len, len <<= 1)
    {
        ll wn = qpow(g, (mod - 1) / len, mod);
        wp[0] = 1;
        for (int i = 1; i <= len; i++) wp[i] = wp[i - 1] * wn % mod;
        for (int l = 0, r = len - 1; r <= n; l += len, r += len)
        {
            int w = 0;
            for (int p = l; p < l + m; p++, w++)
            {
                ll x = A[p], y = wp[w] * A[p + m] % mod;
                A[p] = (x + y) % mod, A[p + m] = (x - y + mod) % mod;
            }
        }
    }
}

void INTT(ll *A, int n)
{
    NTT(A, n);
    reverse(A + 1, A + n);
    int inv = qpow(n, mod - 2, mod);
    for (int i = 0; i < n; i++) A[i] = 1ll * A[i] * inv % mod;
}

void solve(int l, int r, int lg)
{
    if ((!lg) || (l >= n)) return;
    int mid = (l + r) >> 1;
    solve(l, mid, lg - 1);
    memset(Ft + (r - l) / 2, 0, (r - l) * sizeof(ll));
    memcpy(Ft, F + l, (r - l) / 2 * sizeof(ll));
    memcpy(Gt, G, (r - l) * sizeof(ll));
    NTT(Ft, (1 << lg)), NTT(Gt, (1 << lg));
    for (int i = 0; i < r - l; i++) Ft[i] = Ft[i] * Gt[i] % mod;
    INTT(Ft, (1 << lg));
    for (int i = (r - l) / 2; i < r - l; i++) F[l + i] = (F[l + i] + Ft[i]) % mod;
    solve(mid, r, lg - 1);
}

int main()
{
    scanf("%d", &n);
    int lg = 0;
    while ((1 << lg) < n) lg++;
    for (int i = 1; i < n; i++) scanf("%lld", &G[i]);
    F[0] = 1;
    solve(0, (1 << lg), lg);
    for (int i = 0; i < n; i++) printf("%lld ", F[i]); putchar('\n');
    return 0;
}
posted @ 2023-01-13 11:52  kymru  阅读(103)  评论(0编辑  收藏  举报