FFT与NTT的模板

网上相关博客不少,这里给自己留个带点注释的模板,以后要是忘了作提醒用。

洛谷3803多项式乘法裸题为例。

 

FFT:

 1 #include <cstdio>
 2 #include <cmath>
 3 #include <cctype>
 4 #include <algorithm>
 5 #define ri readint()
 6 #define gc getchar()
 7 
 8 int readint() {
 9     int x = 0, s = 1, c = gc;
10     while (c <= 32)    c = gc;
11     if (c == '-')    s = -1, c = gc;
12     for (; isdigit(c); c = gc)    x = x * 10 + c - 48;
13     return x * s;
14 }
15 
16 const int maxn = 4 * 1e6 + 10;
17 const double PI = acos(-1.0);
18 
19 struct Complex {
20     double x, y;
21     Complex(double a = 0, double b = 0):x(a), y(b){}
22 };
23 Complex operator + (Complex A, Complex B) { return Complex(A.x + B.x, A.y + B.y); }
24 Complex operator - (Complex A, Complex B) { return Complex(A.x - B.x, A.y - B.y); }
25 Complex operator * (Complex A, Complex B) { return Complex(A.x * B.x - A.y * B.y, A.x * B.y + A.y * B.x); }
26 
27 Complex a[maxn], b[maxn];
28 int n, m;
29 int r[maxn], l, limit = 1;
30 
31 void fft(Complex *A, int type) {
32     for (int i = 0; i < limit; i++)
33         if (i < r[i])
34             std::swap(A[i], A[r[i]]);
35     //迭代方式模拟递归写法,需要理解递归是怎么做的才能看懂这个
36     for (int mid = 1; mid < limit; mid <<= 1) {
37         //本来单位根是2*PI/len,这里len替换成2*mid,2就约掉了
38         Complex Wn(cos(PI / mid), type * sin(PI / mid));
39         for (int R = mid << 1, j = 0; j < limit; j += R) {
40             Complex w(1, 0);//单位根的k次幂
41             for (int k = 0; k < mid; k++, w = w * Wn) {
42                 //蝴蝶变换
43                 Complex x = A[j+k], y = w * A[j+k+mid];
44                 A[j+k] = x + y;
45                 A[j+k+mid] = x - y;
46             }
47         }
48     }
49 }
50 
51 int main() {
52     n = ri, m = ri;
53     for (int i = 0; i <= n; i++)
54         a[i].x = ri;
55     for (int i = 0; i <= m; i++)
56         b[i].x = ri;
57 
58     while (limit <= n + m) {//长度变为2^l
59         limit <<= 1;
60         l++;
61     }
62     for (int i = 0; i < limit; i++)//二进制镜像
63         r[i] = (r[i>>1] >> 1) | ((i&1) << (l-1));
64     fft(a, 1);
65     fft(b, 1);
66     for (int i = 0; i < limit; i++)
67         a[i] = a[i] * b[i];
68     fft(a, -1);
69     for (int i = 0; i <= n + m; i++)
70         printf("%d ", (int)(a[i].x / limit + 0.5));
71     return 0;
72 }

 

 NTT是用模域取代了复数域,性质相同只是换了单位根,所以板子基本相同。我这两个相比NTT确实比FFT快一点的:

 1 #include <bits/stdc++.h>
 2 #define ll long long
 3 #define ri readll()
 4 #define gc getchar()
 5 #define rep(i, a, b) for (int i = a; i <= b; i++)
 6 using namespace std;
 7 
 8 const int P = 998244353, G = 3, Gi = 332748118, maxn = 4 * 1e6 + 5;
 9 //P的原根为3,3%P的逆元为332748118
10 //原根意味着:3^(P-1) % P = 1,其中P-1是3%P的阶,本应是φ(P),这里恰好为大素数
11 ll n, m;
12 ll a[maxn], b[maxn];
13 int limit = 1, l, r[maxn];
14 
15 ll readll() {
16     ll x = 0ll, s = 1ll;
17     char c = gc;
18     while (c <= 32)    c = gc;
19     if (c == '-')    s = -1ll, c = gc;
20     for (; isdigit(c); c = gc)    x = x * 10 + c - 48;
21     return x * s;
22 }
23 
24 ll ksm(ll a, ll b, int mod) {
25     ll res = 1ll;
26     for (; b; b >>= 1) {
27         if (b & 1)    res = res * a % mod;
28         a = a * a % mod;
29     }
30     return res;
31 }
32 
33 void NTT(ll *A, int flag) {
34     rep(i, 0, limit)
35     if (i < r[i])
36         swap(A[i], A[r[i]]);
37 
38     for (int mid = 1; mid < limit; mid <<= 1) {
39         //如果是变换则单位根为3^[(P-1)/(len)] % P,逆变换则用逆元
40         ll Wn = ksm(flag ? G : Gi, (P-1) / (mid*2), P);
41         for (int R = mid << 1, j = 0; j < limit; j += R) {
42             ll w = 1ll;
43             for (int k = 0; k < mid; k++, w = w * Wn % P) {
44                 ll x = A[j+k], y = A[j+k+mid] * w % P;
45                 A[j+k] = (x + y) % P;
46                 A[j+k+mid] = (x - y + P) % P;
47             }
48         }
49     }
50 }
51 
52 int main() {
53     n = ri, m = ri;
54     rep(i, 0, n)    a[i] = (ri + P) % P;
55     rep(i, 0, m)    b[i] = (ri + P) % P;
56 
57     while (limit < n + m + 1) {
58         limit <<= 1;
59         l++;
60     }
61     rep(i, 0, limit)    r[i] = (r[i>>1] >> 1) | ((i & 1) << (l - 1));
62     NTT(a, 1);    NTT(b, 1);
63     rep(i, 0, limit)    a[i] = a[i] * b[i] % P;
64     NTT(a, 0);
65 
66     ll inv = ksm(limit, P - 2, P);//最后变换回来要乘长度的逆元
67     rep(i, 0, n + m)    printf("%lld ", a[i] * inv % P);
68 
69     return 0;
70 }

 

posted @ 2019-01-15 12:42  AlphaWA  阅读(288)  评论(0编辑  收藏  举报