Luogu 4245 【模板】任意模数NTT

这个题还有一些其他的做法,以后再补,先记一下三模数$NTT$的方法。

发现这个题不取模最大的答案不会超过$10^5 \times 10^9 \times 10^9 = 10^{23}$,也就是说我们可以取三个满足$NTT$性质的模数先算然后再合并起来。

比如三个模数可以分别取$998244353, 1004535809, 469762049$。

那么我们现在要做的就是合并三个同余方程:

$$x \equiv a_1(\mod P_1)$$

$$x \equiv a_2(\mod P_2)$$

$$x \equiv a_3(\mod P_3)$$

直接上$crt$的话会爆$long \ long$,我们需要一些其他技巧。

先用$crt$合并前两个方程,记

$$t = a_1P_2 \times inv(P_2, P_1) + a_2P_1 \times inv(P_1, P_2) $$

相当于

$$x \equiv t (\mod M = P_1P_2)$$

我们设$x = kM + t$,代入第三个方程,

$$kM + t \equiv a_3(\mod P_3)$$

可以解

$$k \equiv (a_3 - t) \times inv(M, P_3) (\mod P_3)$$

最后代回去算出$kM + t$即可。

在计算$t$的时候需要快速乘。

时间复杂度$O(nlogn)$。

注意到几个逆元没有必要计算多次,可以节省大量常数;用$O(1)$快速乘也可以大大加快速度。

Code:

// luogu-judger-enable-o2
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;

const int N = 3e5 + 5;
const ll Mod[] = {998244353LL, 1004535809LL, 469762049LL};

int n, m, lim = 1, pos[N];
ll a[N], b[N], tmp[N], ans[3][N];

template <typename T>
inline void read(T &X) {
    X = 0; char ch = 0; T op = 1;
    for (; ch > '9' || ch < '0'; ch = getchar())
        if (ch == '-') op = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        X = (X << 3) + (X << 1) + ch - 48;
    X *= op;
}

template <typename T>
inline void swap(T &x, T &y) {
    T t = x; x = y; y = t;
} 

inline ll fmul(ll x, ll y, ll P) {
    ll res = 0LL;
    for (; y; y >>= 1) {
        if (y & 1) res = (res + x) % P;
        x = (x + x) % P;
    }
    return res;
}

inline ll fpow(ll x, ll y, ll P) {
    ll res = 1LL;
/*    for (; y > 0; y >>= 1) {
        if (y & 1) res = fmul(res, x, P);
        x = fmul(x, x, P);
    }   */
    
    for (; y > 0; y >>= 1) {
        if (y & 1) res = res * x % P;
        x = x * x % P;
    }
    
    return res;
}

inline ll getInv(ll x, ll y) {
    return fpow(x % y, y - 2, y);
}

inline void prework() {
    int l = 0;
    for (; lim <= n + m; lim <<= 1, ++l);
    for (int i = 0; i < lim; i++)
        pos[i] = (pos[i >> 1] >> 1) | ((i & 1) << (l - 1));
}

inline void ntt(ll *c, int opt, ll P) {
    for (int i = 0; i < lim; i++)
        if (i < pos[i]) swap(c[i], c[pos[i]]);
    for (int i = 1; i < lim; i <<= 1) {
        ll wn = fpow(3, (P - 1) / (i << 1), P);
        if (opt == -1) wn = fpow(wn, P - 2, P);
        for (int len = i << 1, j = 0; j < lim; j += len) {
            ll w = 1LL;
            for (int k = 0; k < i; k++, w = w * wn % P) {
                ll x = c[j + k], y = w * c[j + k + i] % P;
                c[j + k] = (x + y) % P, c[j + k + i] = (x - y + P) % P;
            }
        }
    }
    
    if (opt == -1) {
        ll inv = getInv(lim, P);
        for (int i = 0; i < lim; i++) c[i] = c[i] * inv % P;
    }
}

inline void solve(int id) {
    for (int i = 0; i < lim; i++) 
        tmp[i] = b[i] % Mod[id], ans[id][i] = a[i] % Mod[id];
    ntt(tmp, 1, Mod[id]), ntt(ans[id], 1, Mod[id]);
    for (int i = 0; i < lim; i++) ans[id][i] = ans[id][i] * tmp[i] % Mod[id];
    ntt(ans[id], -1, Mod[id]);
}

inline ll get(int k, ll P) {
    ll M = (Mod[0] * Mod[1]);
    ll t1 = fmul(ans[0][k] * Mod[1] % M, getInv(Mod[1], Mod[0]), M);
    ll t2 = fmul(ans[1][k] * Mod[0] % M, getInv(Mod[0], Mod[1]), M);
    ll t = (t1 + t2) % M;
    ll invM = getInv(M, Mod[2]), c = t;
    t = (ans[2][k] - t % Mod[2] + Mod[2]) % Mod[2];
    t = t * invM % Mod[2];
    return ((M % P) * (t % P) % P + c % P) % P;
}

int main() {
    ll P;
    read(n), read(m), read(P);
    for (int i = 0; i <= n; i++) {
        read(a[i]);
        a[i] %= P;
    }
    for (int i = 0; i <= m; i++) {
        read(b[i]);
        b[i] %= P;
    }
    
    prework();
    for (int i = 0; i < 3; i++) solve(i);
    
/*    for (int i = 0; i < 3; i++, printf("\n"))
        for (int j = 0; j < lim; j++)
            printf("%lld ", ans[i][j]);    */
    
    for (int i = 0; i <= n + m; i++) 
        printf("%lld%c", get(i, P), i == (n + m) ? '\n' : ' ');
        
    return 0;
}
View Code

 

posted @ 2019-01-16 14:19  CzxingcHen  阅读(229)  评论(0编辑  收藏  举报