Luogu 4245 【模板】任意模数NTT
这个题还有一些其他的做法,以后再补,先记一下三模数$NTT$的方法。
发现这个题不取模最大的答案不会超过$10^5 \times 10^9 \times 10^9 = 10^{23}$,也就是说只要取三个满足$NTT$性质的模数,并且它们的乘积大于$10^{23}$,就可以先分别在三个模数下算出结果,然后再合并起来还原出不取模的真实答案。
比如三个模数可以分别取$998244353, 1004535809, 469762049$。
那么我们现在要做的就是合并三个同余方程:
$$x \equiv a_1 \pmod{P_1}$$
$$x \equiv a_2 \pmod{P_2}$$
$$x \equiv a_3 \pmod{P_3}$$
直接上 CRT 的话中间结果会爆 `long long`(三个模数的乘积约为$4.7 \times 10^{26}$),我们需要一些其他技巧。
先用 CRT 合并前两个方程,用$inv(a, p)$表示$a$在模$p$意义下的逆元,记
$$t = a_1P_2 \times inv(P_2, P_1) + a_2P_1 \times inv(P_1, P_2) $$
相当于
$$x \equiv t \pmod{M}, \quad M = P_1P_2$$
我们设$x = kM + t$,代入第三个方程,
$$kM + t \equiv a_3 \pmod{P_3}$$
可以解出
$$k \equiv (a_3 - t) \times inv(M, P_3) \pmod{P_3}$$
最后代回去算出$(kM + t) \bmod P$即可,其中$P$是题目给定的模数;因为$0 \le k < P_3$,所以$kM + t < P_1P_2P_3$,它就是不取模时的真实答案。
在计算$t$的时候需要快速乘:$M \approx 10^{18}$,两个接近$M$的数直接相乘会溢出 `long long`。
时间复杂度$O(n \log n)$。
注意到几个逆元没有必要计算多次,可以节省大量常数;用$O(1)$快速乘也可以大大加快速度。
Code:
// luogu-judger-enable-o2
// Luogu P4245: multiply two polynomials and print every coefficient modulo an
// arbitrary P. The product is computed with NTT under three NTT-friendly
// primes and the three residues are merged with the Chinese Remainder Theorem.
#include <cstdio>
#include <cstring>
#include <utility>

typedef long long ll;

const int N = 300005;
// Largest transform length used; must not exceed N.
const int MAX_LIM = 1 << 18;
const ll Mod[] = {998244353LL, 1004535809LL, 469762049LL};

int n, m, lim = 1, pos[N];
ll a[N], b[N], tmp[N], ans[3][N];

// Reads one (optionally negative) decimal integer from stdin into X.
// Stops at EOF instead of spinning forever; X is 0 if no digit was found.
template <typename T>
inline void read(T &X) {
    X = 0;
    T op = 1;
    int ch = std::getchar();  // int, so EOF is distinguishable from any char
    for (; ch != EOF && (ch < '0' || ch > '9'); ch = std::getchar())
        if (ch == '-') op = -1;
    for (; ch >= '0' && ch <= '9'; ch = std::getchar())
        X = X * 10 + (ch - '0');
    X *= op;
}

// Returns x * y mod P by double-and-add, so no intermediate exceeds 2 * P.
// Requires 0 <= x < P, y >= 0 and 2 * P representable in ll.
inline ll fmul(ll x, ll y, ll P) {
    ll res = 0LL;
    for (; y; y >>= 1) {
        if (y & 1) res = (res + x) % P;
        x = (x + x) % P;
    }
    return res;
}

// Returns x^y mod P by binary exponentiation.
// Requires 0 <= x < P and P * P representable in ll.
inline ll fpow(ll x, ll y, ll P) {
    ll res = 1LL;
    for (; y > 0; y >>= 1) {
        if (y & 1) res = res * x % P;
        x = x * x % P;
    }
    return res;
}

// Returns the inverse of x modulo the prime y (Fermat's little theorem).
inline ll getInv(ll x, ll y) {
    return fpow(x % y, y - 2, y);
}

// CRT constants. They depend only on the three fixed primes, so they are
// computed once here instead of once per output coefficient.
const ll M12 = Mod[0] * Mod[1];                  // about 1.0e18, fits in ll
const ll INV_P2_MOD_P1 = getInv(Mod[1], Mod[0]);
const ll INV_P1_MOD_P2 = getInv(Mod[0], Mod[1]);
const ll INV_M12_MOD_P3 = getInv(M12, Mod[2]);

// Sets lim to the smallest power of two greater than n + m and fills pos[]
// with the bit-reversal permutation of [0, lim).
inline void prework() {
    for (lim = 1; lim <= n + m; lim <<= 1) {}
    pos[0] = 0;
    // Value of the top bit of an index. Using it directly avoids shifting by
    // (log2(lim) - 1), which would be a negative shift count when lim == 1.
    const int top = lim >> 1;
    for (int i = 1; i < lim; i++)
        pos[i] = (pos[i >> 1] >> 1) | ((i & 1) ? top : 0);
}

// In-place iterative NTT of c[0, lim) modulo the prime P (primitive root 3).
// opt == 1 runs the forward transform; opt == -1 runs the inverse transform
// and also divides every element by lim.
inline void ntt(ll *c, int opt, ll P) {
    for (int i = 0; i < lim; i++)
        if (i < pos[i]) std::swap(c[i], c[pos[i]]);

    for (int i = 1; i < lim; i <<= 1) {
        ll wn = fpow(3, (P - 1) / (i << 1), P);
        if (opt == -1) wn = fpow(wn, P - 2, P);
        for (int len = i << 1, j = 0; j < lim; j += len) {
            ll w = 1LL;
            for (int k = 0; k < i; k++, w = w * wn % P) {
                ll x = c[j + k], y = w * c[j + k + i] % P;
                c[j + k] = (x + y) % P;
                c[j + k + i] = (x - y + P) % P;
            }
        }
    }

    if (opt == -1) {
        ll inv = getInv(lim, P);
        for (int i = 0; i < lim; i++) c[i] = c[i] * inv % P;
    }
}

// Computes the product a * b modulo Mod[id] and stores it in ans[id].
inline void solve(int id) {
    const ll p = Mod[id];
    for (int i = 0; i < lim; i++) {
        tmp[i] = b[i] % p;
        ans[id][i] = a[i] % p;
    }
    ntt(tmp, 1, p);
    ntt(ans[id], 1, p);
    for (int i = 0; i < lim; i++) ans[id][i] = ans[id][i] * tmp[i] % p;
    ntt(ans[id], -1, p);
}

// Merges the three residues of coefficient k and returns the true
// coefficient modulo P.
inline ll get(int k, ll P) {
    // Combine the first two congruences into t (mod M12). Each left factor is
    // below M12, but multiplying by the inverse needs fmul to avoid overflow.
    ll t1 = fmul(ans[0][k] * Mod[1] % M12, INV_P2_MOD_P1, M12);
    ll t2 = fmul(ans[1][k] * Mod[0] % M12, INV_P1_MOD_P2, M12);
    ll t = (t1 + t2) % M12;

    // Solve q * M12 + t == ans[2][k] (mod Mod[2]) for q.
    ll q = (ans[2][k] - t % Mod[2] + Mod[2]) % Mod[2];
    q = q * INV_M12_MOD_P3 % Mod[2];

    // q * M12 + t is the exact coefficient; only its residue mod P is needed.
    return ((M12 % P) * (q % P) % P + t % P) % P;
}

int main() {
    ll P;
    read(n), read(m), read(P);

    // Reject input that would overrun the static buffers or divide by zero.
    if (n < 0 || m < 0 || P <= 0 || static_cast<ll>(n) + m >= MAX_LIM) {
        std::fprintf(stderr, "invalid input\n");
        return 1;
    }

    for (int i = 0; i <= n; i++) {
        read(a[i]);
        a[i] %= P;
        if (a[i] < 0) a[i] += P;  // keep residues in [0, P)
    }
    for (int i = 0; i <= m; i++) {
        read(b[i]);
        b[i] %= P;
        if (b[i] < 0) b[i] += P;
    }

    prework();
    for (int i = 0; i < 3; i++) solve(i);

    for (int i = 0; i <= n + m; i++)
        std::printf("%lld%c", get(i, P), i == (n + m) ? '\n' : ' ');
    return 0;
}