FFT | NTT

FFT:

#include <bits/stdc++.h>
// Shorthand macros; of these, only db, Pi and N are used by the FFT program below.
#define ll long long
#define inf 0x3f3f3f3f
#define db double
// Expands to a function call, so acos(-1) is evaluated at every use site.
#define Pi acos(-1)
#define eps 1e-8
// Capacity of the global arrays; it must be at least `limit`, the padded
// transform length (the smallest power of two greater than 2n).
#define N 400005
using namespace std;
// Square of a value, evaluated in the argument's own type T.
template <typename T>
T sqr(T x) {
    const T squared = x * x;
    return squared;
}
struct Complex {
    db x, y;
    Complex (db xx = 0, db yy = 0) { x = xx; y = yy; }
    Complex operator + (Complex B) { return Complex(x + B.x, y + B.y); }
    Complex operator - (Complex B) { return Complex(x - B.x, y - B.y); }
    Complex operator * (Complex B) { return Complex(x * B.x - y * B.y, x * B.y + y * B.x); }
}a[N], b[N], c[N], d[N];
// Bit-reversal permutation table for the current transform length.
int r[N];
// n: number of input values; m: unused in this program;
// limit: transform length (a power of two); l: log2(limit).
int n, m, limit, l;
// In-place iterative radix-2 FFT over a[0 .. limit-1].
// The globals `limit` (a power of two) and `r` (its bit-reversal table) must be
// set up before the call.
// type =  1: forward transform, using the roots exp(+i*pi*k/mid).
// type = -1: inverse transform; the whole result is divided by `limit`.
void FFT(Complex *a, int type) {
    for (int i = 0; i < limit; i ++) if (i < r[i]) swap(a[i], a[r[i]]);
    for (int mid = 1; mid < limit; mid <<= 1) {
        // Principal root of unity for butterflies of length 2*mid.
        Complex wn(cos(Pi / mid), type * sin(Pi / mid));
        for (int R = mid << 1, j = 0; j < limit; j += R) {
            Complex w(1, 0);
            for (int k = 0; k < mid; k ++, w = w * wn) {
                Complex x = a[j + k], y = w * a[j + mid + k];
                a[j + k] = x + y; a[j + mid + k] = x - y;
            }
        }
    }
    if (type == -1) {
        // The inverse DFT scales the full complex value by 1/limit. Normalizing
        // only x would leave y off by a factor of limit for any caller that
        // reads the imaginary part.
        const db scale = 1.0 * limit;
        for (int i = 0; i < limit; i ++) { a[i].x /= scale; a[i].y /= scale; }
    }
    // Alternative for integer-valued results (rounds the real part):
    // if( type == -1 ) for(int i=0;i<limit;i++)a[i].x = (int)(a[i].x/limit+0.5);
}
int main() {
    scanf("%d",&n);
    for (int i = 1; i <= n; i ++) {
        db x; scanf("%lf", &x);
        a[i].x = c[n - i + 1].x = x; b[i].x = 1.0 / sqr(i * 1.0);
    }
    limit = 1; while (limit <= (n << 1)) limit <<= 1, l ++;
    for (int i = 0; i < limit; i ++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    FFT(a, 1); FFT(b, 1);FFT(c, 1);
    for (int i = 0; i < limit; i ++) a[i] = a[i] * b[i];
    FFT(a, -1);
    for (int i = 0; i < limit; i ++) c[i] = c[i] * b[i];
    FFT(c, -1);
    for (int i = 1; i <= n; i ++) printf("%.3lf\n", a[i].x - c[n - i + 1].x);
    return 0;
}

NTT:

#include <bits/stdc++.h>

// Container / iteration shorthand macros (none are used by the visible code).
#define ed end()
#define bg begin()
#define mp make_pair
#define pb push_back
#define all(x) x.bg,x.ed
#define newline puts("")
#define si(x) ((int)x.size())
// rep: i = 1..n; rrep: i = 0..n-1; srep: i = s..t ascending; drep: i = t..s descending.
#define rep(i,n) for(int i=1;i<=n;++i)
#define rrep(i,n) for(int i=0;i<n;++i)
#define srep(i,s,t) for(int i=s;i<=t;++i)
#define drep(i,s,t) for(int i=t;i>=s;--i)

// Debug printing helpers: d1/d2 print "name = value" pairs, disp prints
// arry[fr..to] on one line. DEBUG is defined but never tested in this file.
#define DEBUG
#define d1(x) std::cout << #x " = " << (x) << std::endl
#define d2(x, y) std::cout << #x " = " << (x) << " ," #y " = " << (y) << std::endl
#define disp(arry, fr, to) \
    { \
        std::cout << #arry " : "; \
        for(int _i = fr; _i <= to; _i++) std::cout << arry[_i] << " "; \
        std::cout << std::endl; \
    }

using namespace std;
typedef long long LL;
typedef long long ll;
typedef pair<int,int> pii;
// Array capacity. On the NTT path the padded length `limit` (smallest power of
// two > n + m) must fit, which caps n + m at 262143.
const int Maxn = 4e5 + 10;
const int Inf = 0x7f7f7f7f;
const LL Inf_ll = 1ll*Inf*Inf;
// Modulus used by add/sub/mul/mypow; the same prime as P below.
const int Mod = 998244353;
const double eps = 1e-7;

// NTT prime P, primitive root G, and Gi = G^{-1} mod P (3 * 332748118 = P + 1).
const int MAXN = 4e5 + 10, P = 998244353, G = 3, Gi = 332748118; 
// limit: transform length (power of two); L: log2(limit);
// r: bit-reversal permutation for the current limit. N and M are unused here.
int N, M, limit = 1, L, r[MAXN];

// Modular addition. Assumes x and y are already in [0, Mod), so a single
// conditional subtraction brings the sum back into range.
inline int add(const int x, const int y){
    const int sum = x + y;
    return sum >= Mod ? sum - Mod : sum;
}
// Modular subtraction. Assumes x and y are already in [0, Mod); a negative
// difference is wrapped back by adding Mod once.
inline int sub(const int x, const int y){
    const int diff = x - y;
    return diff < 0 ? diff + Mod : diff;
}
// Modular multiplication. The product is formed in 64 bits so it cannot
// overflow before the reduction.
inline int mul(const int x, const int y){
    const long long prod = static_cast<long long>(x) * y;
    return static_cast<int>(prod % Mod);
}
// Computes a^b mod Mod by square-and-multiply; returns 1 when b <= 0.
inline int mypow(int a, int b){
    int result = 1;
    for( ; b > 0; b >>= 1, a = mul(a, a) )
        if( b & 1 )  result = mul(result, a);
    return result;
}
// Modular inverse via Fermat's little theorem (Mod is prime):
// a^(Mod-2) * a = a^(Mod-1) = 1 (mod Mod) for a not divisible by Mod.
inline int inv_int(int a){
    const int fermat_exponent = Mod - 2;
    return mypow(a, fermat_exponent);
}

// In-place iterative NTT of A[0 .. limit-1] modulo P.
// The globals `limit` and `r` (bit-reversal table) must already be set.
// type == 1: forward transform with root G. Any other type: roots from Gi;
// only type == -1 additionally multiplies the result by limit^{-1}.
inline void NTT(int *A, int type) {
    for(int i = 0; i < limit; ++i){
        const int j = r[i];
        if( i < j ) swap(A[i], A[j]);
    }
    const int root = (type == 1) ? G : Gi;
    // `half` is the half-length of the sub-transforms merged at this level.
    for(int half = 1; half < limit; half <<= 1){
        const int len = half << 1;
        const int step = mypow(root, (P - 1) / len);
        for(int start = 0; start < limit; start += len){
            int w = 1;
            for(int k = 0; k < half; ++k){
                int &lo = A[start + k], &hi = A[start + k + half];
                const int u = lo, v = mul(w, hi);
                lo = add(u, v);
                hi = sub(u, v);
                w = mul(w, step);
            }
        }
    }
    if( type == -1 ){
        const int scale = inv_int(limit);
        for(int i = 0; i < limit; ++i)
            A[i] = mul(A[i], scale);
    }
}

// Reads the next decimal integer from stdin, skipping any non-digit prefix.
// A '-' anywhere in the skipped prefix makes the result negative.
// Returns 0 if end of input is reached before any digit.
inline int read(){
    // Keep getchar()'s int result so EOF stays distinct from every real byte.
    // The skip loop must stop at EOF: getchar() keeps returning EOF, so without
    // the check a missing number would spin forever.
    int c = getchar();
    int x = 0;
    bool neg = false;
    while( c != EOF && (c < '0' || c > '9') ){
        neg |= (c == '-');
        c = getchar();
    }
    while( c >= '0' && c <= '9' ){
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return neg ? -x : x;
}
// Writes x in decimal to stdout with a leading '-' for negative values
// (no separator or newline is appended).
void print(int x){
    // Negate in unsigned arithmetic: -x would overflow (UB) for INT_MIN.
    unsigned int u = static_cast<unsigned int>(x);
    if( x < 0 ){
        putchar('-');
        u = 0u - u;
    }
    // Collect digits least-significant first, then emit them in reverse.
    char digits[24];
    int len = 0;
    do {
        digits[len++] = static_cast<char>('0' + u % 10);
        u /= 10;
    } while( u != 0 );
    while( len > 0 ) putchar(digits[--len]);
}

// Multiplies polynomials f (degree n) and g (degree m) modulo Mod and writes
// the n + m + 1 product coefficients to h[0 .. n + m].
// If either degree is below 128, schoolbook multiplication is used and f, g
// are only read. Otherwise f and g are zero-padded and overwritten with their
// forward transforms, and f, g, h must each hold at least `limit` entries
// (the smallest power of two greater than n + m).
inline void NTT_mul(int *f, int *g, int *h, int n, int m){
    if( n < 128 || m < 128 ){
        for(int k = 0; k <= n + m; ++k)  h[k] = 0;
        for(int i = 0; i <= n; ++i)
            for(int j = 0; j <= m; ++j)
                h[i + j] = add(h[i + j], mul(f[i], g[j]));
        return;
    }
    L = 0;
    for(limit = 1; limit <= n + m; limit <<= 1)  ++L;
    for(int i = 0; i < limit; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    for(int i = n + 1; i < limit; ++i)  f[i] = 0;
    for(int i = m + 1; i < limit; ++i)  g[i] = 0;
    NTT(f, 1);
    NTT(g, 1);
    for(int i = 0; i < limit; ++i)  h[i] = mul(f[i], g[i]);
    NTT(h, -1);
}

// Scratch buffers referenced only by the commented-out solve() below; main does not use them.
int A[Maxn], B[Maxn], C[Maxn];

// vector<int> solve(int l, int r){
//     vector<int> now;
//     if( l == r )
//     {
//         now = vector<int>{1,a[l]};
//         return now;
//     }
//     int len = r-l+1, mid = (r+l) >> 1, l1 = mid-l+1, l2 = r-mid;
//     vector<int> vl = solve(l, mid), vr = solve(mid+1, r);
//     for(int i=0;i<=l1;i++)  A[i] = vl[i];
//     for(int i=0;i<=l2;i++)  B[i] = vr[i];
//     NTT_mul(A, B, C, l1, l2);
//     now.resize(len+1);
//     for(int i=0;i<=len;i++)  now[i] = C[i];
//     return now;
// }

int n, m, a[Maxn], b[Maxn], c[Maxn];
// Reads degrees n and m, then the n + 1 and m + 1 coefficients of the two
// polynomials, and prints the n + m + 1 product coefficients, each followed
// by a space.
int main() {
    n = read();
    m = read();
    for(int i = 0; i <= n; ++i)  a[i] = read();
    for(int i = 0; i <= m; ++i)  b[i] = read();
    NTT_mul(a, b, c, n, m);
    const int deg = n + m;
    for(int i = 0; i <= deg; ++i){
        print(c[i]);
        putchar(' ');
    }
	return 0;
}

任意模数 NTT (NTT with an arbitrary modulus, via three NTT primes and CRT):

#include <algorithm>
#include <cstdio>
#include <cstring>
// Output modulus: read in main and used by Int::get() to reduce the CRT result.
int mod;
namespace Math {
    // Returns base^p mod `mod` by binary exponentiation (p >= 0 expected).
    // `base` may be any int, including negative values or values >= mod.
    // There is no static state, so the function is reentrant.
    inline int pw(int base, int p, const int mod) {
        long long b = base % mod;
        if (b < 0) b += mod;
        long long res = 1 % mod;
        for (; p > 0; p >>= 1, b = b * b % mod)
            if (p & 1) res = res * b % mod;
        return static_cast<int>(res);
    }
    // Modular inverse via Fermat's little theorem; valid only when `mod` is
    // prime and x is not a multiple of it.
    inline int inv(int x, const int mod) { return pw(x, mod - 2, mod); }
}
 
// Three NTT-friendly primes; G = 3 is used as the generator for all of them.
// CRT can recover exact convolution values below their product (~4.7e26).
const int mod1 = 998244353, mod2 = 1004535809, mod3 = 469762049, G = 3;
// mod1 * mod2, used when lifting the two-prime residue to the third prime.
const long long mod_1_2 = static_cast<long long> (mod1) * mod2;
// inv_1 = mod1^{-1} (mod mod2); inv_2 = (mod1 * mod2)^{-1} (mod mod3).
const int inv_1 = Math::inv(mod1, mod2), inv_2 = Math::inv(mod_1_2 % mod3, mod3);
// A value held simultaneously as residues modulo the three NTT primes
// (A mod mod1, B mod mod2, C mod mod3). Arithmetic is component-wise; get()
// recombines the residues with the Chinese Remainder Theorem and reduces the
// result modulo the global output modulus `mod`.
// +, - and get() rely on every component lying in [0, modK).
struct Int {
    int A, B, C;
    // Zero-initialized so a default-constructed Int never holds garbage.
    explicit inline Int() : A(0), B(0), C(0) { }
    // Broadcasts one integer to all three moduli, reducing it into each range
    // so the component invariant holds for any int (including values >= mod3
    // or negative ones).
    explicit inline Int(int num) : A(norm(num, mod1)), B(norm(num, mod2)), C(norm(num, mod3)) { }
    // Raw constructor: stores the components exactly as given.
    explicit inline Int(int a, int b, int c) : A(a), B(b), C(c) { }
    // Adds modK to each negative component, mapping [-modK, modK) onto [0, modK).
    // (x >> 31) is all ones for negative x, assuming a 32-bit int with
    // arithmetic right shift.
    static inline Int reduce(const Int &x) {
        return Int(x.A + (x.A >> 31 & mod1), x.B + (x.B >> 31 & mod2), x.C + (x.C >> 31 & mod3));
    }
    inline friend Int operator + (const Int &lhs, const Int &rhs) {
        // Offsetting by -modK puts the sum in [-modK, modK), which reduce() maps back.
        return reduce(Int(lhs.A + rhs.A - mod1, lhs.B + rhs.B - mod2, lhs.C + rhs.C - mod3));
    }
    inline friend Int operator - (const Int &lhs, const Int &rhs) {
        return reduce(Int(lhs.A - rhs.A, lhs.B - rhs.B, lhs.C - rhs.C));
    }
    inline friend Int operator * (const Int &lhs, const Int &rhs) {
        return Int(static_cast<long long> (lhs.A) * rhs.A % mod1, static_cast<long long> (lhs.B) * rhs.B % mod2, static_cast<long long> (lhs.C) * rhs.C % mod3);
    }
    // Garner-style CRT reconstruction, reduced modulo `mod`.
    inline int get() const {
        // x is the value in [0, mod1 * mod2) with x = A (mod mod1) and x = B (mod mod2).
        long long x = static_cast<long long> (B - A + mod2) % mod2 * inv_1 % mod2 * mod1 + A;
        // Add k * mod1 * mod2 with k chosen so the total is congruent to C (mod mod3).
        // That term is evaluated modulo `mod` to stay within 64 bits.
        return (static_cast<long long> (C - x % mod3 + mod3) % mod3 * inv_2 % mod3 * (mod_1_2 % mod) % mod + x) % mod;
    }
private:
    // Reduces v into [0, m).
    static inline int norm(int v, int m) {
        const int r = v % m;
        return r < 0 ? r + m : r;
    }
} ;
 
// Half the transform capacity: Poly's tables and the operand arrays hold
// maxn << 1 entries, so n + m (coefficient counts) must not exceed that.
#define maxn 131072
 
namespace Poly {
#define N (maxn << 1)
    int lim, s, rev[N];
    Int Wn[N | 1];
    inline void init(int n) {
        s = -1, lim = 1; while (lim < n) lim <<= 1, ++s;
        for (register int i = 1; i < lim; ++i) rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
        const Int t(Math::pw(G, (mod1 - 1) / lim, mod1), Math::pw(G, (mod2 - 1) / lim, mod2), Math::pw(G, (mod3 - 1) / lim, mod3));
        *Wn = Int(1); for (register Int *i = Wn; i != Wn + lim; ++i) *(i + 1) = *i * t;
    }
    inline void NTT(Int *A, const int op = 1) {
        for (register int i = 1; i < lim; ++i) if (i < rev[i]) std::swap(A[i], A[rev[i]]);
        for (register int mid = 1; mid < lim; mid <<= 1) {
            const int t = lim / mid >> 1;
            for (register int i = 0; i < lim; i += mid << 1) {
                for (register int j = 0; j < mid; ++j) {
                    const Int W = op ? Wn[t * j] : Wn[lim - t * j];
                    const Int X = A[i + j], Y = A[i + j + mid] * W;
                    A[i + j] = X + Y, A[i + j + mid] = X - Y;
                }
            }
        }
        if (!op) {
            const Int ilim(Math::inv(lim, mod1), Math::inv(lim, mod2), Math::inv(lim, mod3));
            for (register Int *i = A; i != A + lim; ++i) *i = (*i) * ilim;
        }
    }
#undef N
}
 
// Coefficient counts (main turns the input degrees into counts with ++) and the
// two operand arrays, transformed in place; A ends up holding the product.
int n, m;
Int A[maxn << 1], B[maxn << 1];
// Reads degrees n, m and the output modulus, then the coefficients of both
// polynomials, and prints the n + m + 1 product coefficients modulo `mod`,
// separated by spaces and terminated by a newline.
int main() {
    scanf("%d%d%d", &n, &m, &mod);
    // Switch from degrees to coefficient counts.
    ++n;
    ++m;
    int x;
    for (int i = 0; i < n; ++i) {
        scanf("%d", &x);
        A[i] = Int(x % mod);
    }
    for (int i = 0; i < m; ++i) {
        scanf("%d", &x);
        B[i] = Int(x % mod);
    }
    Poly::init(n + m);
    Poly::NTT(A);
    Poly::NTT(B);
    for (int i = 0; i < Poly::lim; ++i) A[i] = A[i] * B[i];
    Poly::NTT(A, 0);
    const int total = n + m - 1;  // number of product coefficients
    for (int i = 0; i < total; ++i) {
        printf("%d", A[i].get());
        putchar(i + 1 == total ? '\n' : ' ');
    }
    return 0;
}
posted @ 2020-08-05 10:54  HexQwQ  阅读(121)  评论(0编辑  收藏  举报