[模版] 快速傅里叶变换

原理见:快速傅里叶变换 - OI Wiki

大概的理解:多项式函数系数函数点值的变换和逆变换,利用单位复根的”旋转“性质实现,分治法实现O(nlogn)。

应用:多项式乘法,卷积的加速(一些dp式子的加速),字符串匹配(实现不完全匹配),大数相乘......

模版:

有递归实现和非递归实现,递归实现直观但较慢;非递归实现需要先对系数数组进行蝴蝶变换,即重新排序,使得可以两两合并,类似非递归的归并排序。

  • 非递归普通fft

普通的fft受到double的精度限制。如果卷积过程中最大值超过1e17,就要使用long double;如果超过了1e23,一般要取模,则使用ntt或者mtt。

通过预处理所有单位根可以提高精度,因为如果在计算过程中迭代计算单位根会有复数乘法带来的精度误差。

使用stl的复数模版可能会很慢,推荐使用手写的。

struct cp {
    double x, y;
    cp(double xx = 0, double yy = 0) { x = xx, y = yy; }
    cp operator+(cp a) { return cp(x + a.x, y + a.y); }
    cp operator-(cp a) { return cp(x - a.x, y - a.y); }
    cp operator*(cp a) { return cp(x * a.x - y * a.y, x * a.y + y * a.x); }
};
int rev[N];
// 非递归模板
// 同样需要保证 len 是 2 的幂
// 记 rev[i] 为 i 翻转后的值
void change(cp y[], int len) { // 蝴蝶变换
    for(int i = 0; i < len; ++i) {
        rev[i] = rev[i >> 1] >> 1;
        if(i & 1) {  // 如果最后一位是 1,则翻转成 len/2
            rev[i] |= len >> 1;
        }
    }
    for(int i = 0; i < len; ++i) {
        if(i < rev[i]) {  // 保证每对数只翻转一次
            swap(y[i], y[rev[i]]);
        }
    }
    return;
}
/*
 * 做 FFT
 * len 必须是 2^k 形式
 * on == 1 时是 DFT,on == -1 时是 IDFT
 */
void fft(cp y[], int len, int on) {
    change(y, len);
    for(int h = 2; h <= len; h <<= 1) {                  // 模拟合并过程
        cp wn(cos(2 * PI / h), sin(on * 2 * PI / h));  // 计算当前单位复根
        for(int j = 0; j < len; j += h) {
            cp w(1, 0);  // 计算当前单位复根
            for(int k = j; k < j + h / 2; k++) {
              cp u = y[k];
              cp t = w * y[k + h / 2];
              y[k] = u + t;  // 这就是把两部分分治的结果加起来
              y[k + h / 2] = u - t;
              // 后半个 “step” 中的ω一定和 “前半个” 中的成相反数
              // “红圈”上的点转一整圈“转回来”,转半圈正好转成相反数
              // 一个数相反数的平方与这个数自身的平方相等
              w = w * wn;
            }
        }
    }
    if(on == -1) {
        for(int i = 0; i < len; i++) {
            y[i].x /= len;
        }
    }
}
  • 数论变换ntt

利用原根在模的意义下单位根。如果模数为\(P=2^n+1\),最大支持\(2^n\)的卷积。常用模数为\(998244353\)\(1004535809\)\(469762049\),它们的原根都是3。

const int N = 3e5 + 10;
const int M = 998244353;
int rev[N];
inline ll qpow(ll a, ll b, ll m) {
    ll res = 1;
    while (b) {
        if (b & 1)
            res = (res * a) % m;

        a = (a * a) % m;
        b = b >> 1;
    }
    return res;
}

void change(ll y[], int len) { // 蝴蝶变换
    for (int i = 0; i < len; ++i) {
        rev[i] = rev[i >> 1] >> 1;
        if (i & 1) {
            rev[i] |= len >> 1;
        }
    }
    for (int i = 0; i < len; ++i) {
        if (i < rev[i]) {
            swap(y[i], y[rev[i]]);
        }
    }
    return;
}

void ntt(ll y[], int len, int on) { // -1逆变换
    change(y, len);
    for (int h = 2; h <= len; h <<= 1) {
        ll gn = qpow(3, (M - 1) / h, M); // 原根为3
        if (on == -1)
            gn = qpow(gn, M - 2, M);
        for (int j = 0; j < len; j += h) {
            ll g = 1;

            for (int k = j; k < j + h / 2; k++) {
                ll u = y[k];
                ll t = g * y[k + h / 2] % M;
                y[k] = (u + t) % M;
                y[k + h / 2] = (u - t + M) % M;
                g = g * gn % M;
            }
        }
    }
    if (on == -1) {
        ll inv = qpow(len, M - 2, M);
        for (int i = 0; i < len; i++) {
            y[i] = y[i] * inv % M;
        }
    }
}
  • 任意模数mtt

有时候, 题目要求对某个\(P\)取模,而\(P\)无法为写成\(2^n+1\),无法使用ntt;而且\(P\)很大,导致卷积过程中最大值\(P^2n^2\)(因为idft还要除以n)超出了doublelong double的精度限制。这时有几种方法解决:

  1. 中国剩余定理

一般选用多个质数得到结果再用crt合并。只要\(P^2n\)小于选取的质数之积,就可以表示。

一般选用998244353,1004535809,469762049,可以表示\(10^{26}\)的值域,而且原根均为3,好写。

时间复杂度为9个ntt操作。太慢,一般不用。

  1. 拆系数

将多项式每个系数写成\(x_1\sqrt{P}+x_2\)的形式,可得

\[(x_1\sqrt{P}+x_2)(y_1\sqrt{P}+y_2)=Px_1y_1+\sqrt{P}(x_1y_2+x_2y_1)+x_2y_2 \]

这样每一项都小于\(\sqrt{P}\),需要卷积4组多项式,值域变为\(Pn^2\)(因为idft还要除以n),瞬间小很多,可以用fft解决了。

时间复杂度7次fft,[模版]任意模数fft

#include <bits/stdc++.h>

#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N) 
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f

const int N = 3e5 + 10;
const double eps = 1e-5;
long double PI = acos((long double)-1);
int rev[N];

struct cp {
    long double x, y;
    cp(long double xx = 0, long double yy = 0) { x = xx, y = yy; }
    cp operator+(cp a) { return cp(x + a.x, y + a.y); }
    cp operator-(cp a) { return cp(x - a.x, y - a.y); }
    cp operator*(cp a) { return cp(x * a.x - y * a.y, x * a.y + y * a.x); }
};

void change(cp y[], int len) { // 蝴蝶变换
    for(int i = 0; i < len; ++i) {
        rev[i] = rev[i >> 1] >> 1;
        if(i & 1) {  // 如果最后一位是 1,则翻转成 len/2
            rev[i] |= len >> 1;
        }
    }
    for(int i = 0; i < len; ++i) {
        if(i < rev[i]) {  // 保证每对数只翻转一次
            swap(y[i], y[rev[i]]);
        }
    }
    return;
}
void fft(cp y[], int len, int on) {
    change(y, len);
    for(int h = 2; h <= len; h <<= 1) {
        cp wn(cos(2 * PI / h), sin(on * 2 * PI / h));
        for(int j = 0; j < len; j += h) {
            cp w(1, 0);  // 计算当前单位复根
            for(int k = j; k < j + h / 2; k++) {
              cp u = y[k];
              cp t = w * y[k + h / 2];
              y[k] = u + t;
              y[k + h / 2] = u - t;
              w = w * wn;
            }
        }
    }
    if(on == -1) {
        for(int i = 0; i < len; i++) {
            y[i].x /= len;
            y[i].y /= len;
        }
    }
}

cp ax[N], ay[N], bx[N], by[N];
cp rpp[N], rp[N], r[N];

void mtt(int x[], int y[], int res[], int len, int P) { // x,y卷积,P模数,结果放在res
	ll p = 1;
	while(p * p < P) p++;
	for(int i = 0; i < len; i++) {
        ax[i] = x[i] / p;
        bx[i] = x[i] % p;
    }
    for(int i = 0; i < len; i++) {
        ay[i] = y[i] / p;
        by[i] = y[i] % p;
    }
	fft(ax, len, 1);
    fft(ay, len, 1);
    fft(bx, len, 1);
    fft(by, len, 1);

	for(int i = 0; i < len; i++) {
        rpp[i] = ax[i] * ay[i];
        rp[i] = ax[i] * by[i] + ay[i] * bx[i];
        r[i] = bx[i] * by[i];
    }

	fft(rpp, len, -1);
    fft(rp, len, -1);
    fft(r, len, -1);

	for(int i = 0; i < len; i++) {
		res[i] = (p * p % P * (ll)round(rpp[i].x) % P % P + p * (ll)round(rp[i].x) % P + (ll)round(r[i].x) % P) % P;
	}
}

int get(int x) {
    int res = 1;
    while(res < x) {
        res <<= 1;
    }
    return res;
}

int res[N], x[N], y[N];
int main() {
    IOS;
    int n, m, P;
    cin >> n >> m >> P;
    n++;
    m++;
	for(int i = 0; i < n; i++) {
		cin >> x[i];
		x[i] %= P;
	}
	for(int i = 0; i < m; i++) {
		cin >> y[i];
		y[i] %= P;
	}
    int len = 2 * max(get(n), get(m));
    mtt(x, y, res, len, P);
    for(int i = 0; i < n + m - 1; i++) {
       cout << res[i] << " \n"[i == n + m - 2];
    }
}
posted @ 2021-08-02 10:19  limil  阅读(111)  评论(0编辑  收藏  举报