【模板】任意模数多项式乘法:三模 NTT
前置知识
https://www.cnblogs.com/caijianhong/p/template-crt.html
https://www.cnblogs.com/caijianhong/p/template-fft.html
题目描述
任意模数多项式乘法
solution
首先我们打开 https://blog.miskcoo.com/2014/07/fft-prime-table 这篇文章找到 \(998244353\) 附近的几个质数:
- \(167772161 = 5 \times 2^{25} + 1\) 的原根为 \(g = 3\)
- \(469762049 = 7 \times 2^{26} + 1\) 的原根为 \(g = 3\)
- \(998244353 = 119 \times 2^{23} + 1\) 的原根为 \(g = 3\)
- \(1004535809 = 479 \times 2^{21} + 1\) 的原根为 \(g = 3\)
- \(2013265921 = 15 \times 2^{27} + 1\) 的原根为 \(g = 31\)
因为不取模时卷积结果每一项的值域为 \(p^2\max(n, m)\),大约为 \(10^{24}\),所以可以选择三个质数,使它们的乘积 \(>10^{24}\),这样 CRT 合并后才能唯一还原出真实值。比如说可以选 \(998244353\) 与它前面的两个质数(\(167772161\) 和 \(469762049\)),刚好原根都是 \(3\)。
然后以三个质数为模数分别跑三次 NTT 求卷积。
最后 CRT 合并答案,就是正常的中国剩余定理。
实现
注意事项:
- `std::is_integral<__int128>::value` 在 -std=c++14 下为 false。
- CF 只有 C++20 支持 `__int128`。
点击查看代码
#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
typedef long long LL;
// Constraint helper: names `int` when T is an integral type; otherwise the
// substitution fails and the constrained overload is discarded.
template <class T>
using must_int = typename enable_if<is_integral<T>::value, int>::type;

/**
 * Residue class modulo the compile-time constant `umod`.
 *
 * The stored representative `v` is kept in [0, umod). The addition and
 * subtraction routines rely on 2 * umod fitting in an unsigned 32-bit value.
 */
template <unsigned umod>
struct modint {
  static constexpr int mod = umod;
  unsigned v;

  modint() : v(0) {}

  // Builds a residue from any integral value; negative inputs are mapped
  // into [0, mod).
  template <class T, must_int<T> = 0>
  modint(T x) {
    x %= mod;
    if (x < 0) {
      v = x + mod;
    } else {
      v = x;
    }
  }

  modint operator+() const { return *this; }

  modint operator-() const {
    modint negated;
    negated -= *this;
    return negated;
  }

  // Returns the stored representative as a plain int.
  friend int raw(const modint &self) { return static_cast<int>(self.v); }

  friend ostream &operator<<(ostream &os, const modint &self) {
    os << raw(self);
    return os;
  }

  modint &operator+=(const modint &rhs) {
    const unsigned sum = v + rhs.v;
    v = sum >= umod ? sum - umod : sum;
    return *this;
  }

  modint &operator-=(const modint &rhs) {
    // Unsigned wrap-around makes a "negative" difference appear as a value
    // >= umod, which is then shifted back into range.
    const unsigned diff = v - rhs.v;
    v = diff >= umod ? diff + umod : diff;
    return *this;
  }

  modint &operator*=(const modint &rhs) {
    const unsigned long long wide = static_cast<unsigned long long>(v) * rhs.v;
    v = wide % umod;
    return *this;
  }

  // Division multiplies by rhs^(mod - 2); asserts that rhs is nonzero.
  modint &operator/=(const modint &rhs) {
    assert(rhs.v);
    const modint inverse = qpow(rhs, mod - 2);
    return *this *= inverse;
  }

  // Binary exponentiation: a^b for an integral exponent b.
  template <class T, must_int<T> = 0>
  friend modint qpow(modint a, T b) {
    modint result = 1;
    while (b) {
      if (b & 1) result *= a;
      a *= a;
      b >>= 1;
    }
    return result;
  }

  friend modint operator+(modint lhs, const modint &rhs) {
    lhs += rhs;
    return lhs;
  }
  friend modint operator-(modint lhs, const modint &rhs) {
    lhs -= rhs;
    return lhs;
  }
  friend modint operator*(modint lhs, const modint &rhs) {
    lhs *= rhs;
    return lhs;
  }
  friend modint operator/(modint lhs, const modint &rhs) {
    lhs /= rhs;
    return lhs;
  }

  bool operator==(const modint &rhs) const { return v == rhs.v; }
  bool operator!=(const modint &rhs) const { return !(v == rhs.v); }
};
// Returns the smallest power of two that is >= n (1 for n <= 1).
int glim(int n) {
  // __builtin_clz(0) is undefined, so the n <= 1 case (n - 1 <= 0) must be
  // answered before the bit width of n - 1 is computed.
  if (n <= 1) return 1;
  return 1 << (32 - __builtin_clz(n - 1));
}
// Number of trailing zero bits of n. n must be nonzero: the underlying
// builtin is undefined for 0.
int bitctz(int n) {
  const int trailing_zeros = __builtin_ctz(n);
  return trailing_zeros;
}
/**
 * In-place iterative number-theoretic transform of `a`.
 *
 * op == -1 selects the inverse transform (reverse a[1..] and scale by 1/n);
 * any other value performs the forward transform only. The root table is
 * built from 3 as the generator of mint's modulus.
 * NOTE(review): the bit-reversal mask assumes a.size() is a power of two.
 */
template <class mint>
void ntt(vector<mint> &a, int op) {
  // roots[s] = 3^((mod - 1) / 2^(s + 1)); filled once per mint type.
  static vector<mint> roots;
  if (roots.empty()) {
    for (int e = mint::mod - 1; e % 2 == 0;) {
      e >>= 1;
      roots.push_back(qpow(mint(3), e));
    }
  }

  const int n = a.size();

  // Bit-reversal permutation; `rev` is updated incrementally from the
  // previous index instead of being recomputed from scratch.
  int rev = 0;
  for (int i = 1; i < n; ++i) {
    const int flip_mask = n - (1 << (bitctz(n) - bitctz(i) - 1));
    rev ^= flip_mask;
    debug("i = %d, r = %d\n", i, rev);
    if (rev > i) swap(a[rev], a[i]);
  }

  // Butterfly passes with block size doubling each round.
  vector<mint> w(n);
  for (int half = 1; 2 * half <= n; half *= 2) {
    const int span = 2 * half;
    const mint step = roots[bitctz(half)];
    w[0] = 1;
    for (int i = 1; i < half; ++i) w[i] = w[i - 1] * step;
    for (int base = 0; base < n; base += span) {
      for (int j = 0; j < half; ++j) {
        const mint lo = a[base + j];
        const mint hi = a[base + j + half] * w[j];
        a[base + j] = lo + hi;
        a[base + j + half] = lo - hi;
      }
    }
  }

  if (op == -1) {
    reverse(a.begin() + 1, a.end());
    const mint inv_n = mint(1) / n;
    for (mint &coef : a) coef *= inv_n;
  }
}
/**
 * Multiplies two coefficient vectors over the field of `mint` using NTT.
 *
 * Returns a vector of a.size() + b.size() - 1 coefficients, or an empty
 * vector when either operand is empty.
 */
template <class mint>
vector<mint> convolution(vector<mint> a, vector<mint> b) {
  // Without this guard rlen would be <= 0 for two empty operands, and a lone
  // empty operand would yield a non-empty all-zero result.
  if (a.empty() || b.empty()) return {};
  const int rlen = a.size() + b.size() - 1;
  // Product of two constants: no transform needed, and it avoids asking
  // glim for the transform size of a length-1 result.
  if (rlen == 1) return {a[0] * b[0]};
  const int len = glim(rlen);
  a.resize(len), ntt(a, 1);
  b.resize(len), ntt(b, 1);
  for (int i = 0; i < len; i++) a[i] *= b[i];
  ntt(a, -1), a.resize(rlen);
  return a;
}
主要的部分在这里,实现比较逆天,供参考。其中那些巨大数字由 python 程序生成:
#!/bin/env python3
# Generates the CRT coefficients for three NTT primes and prints a C++ snippet
# (a literal operator plus the merge expression) that uses them.
n = 3
a = [998244353, 1004535809, 469762049]
M = a[0] * a[1] * a[2]
m = [M // a[i] for i in range(n)]
t = [pow(m[i], -1, a[i]) for i in range(n)]  # modular inverse of m[i] modulo a[i]
coe = [m[i] * t[i] % M for i in range(n)]
typ = "unsigned __int128"
prog = f"""
{typ} operator\"\"_ubi(const char* str) {{
int len = strlen(str);
{typ} x = 0;
for (int i = 0; i < len; i++) x = x * 10 + str[i] - '0';
return x;
}}
"""
print(prog)
print(*[f"({coe[i]}_ubi * raw(ret{i + 1}[i]) % {M}_ubi)" for i in range(n)], sep = " + ")
# Sanity check: merging the residues of a known value with the coefficients
# computed above must give that value back. Using `coe` and `M` directly
# (instead of pasted literals) keeps the check valid if the primes change.
x = 1463
check = sum(coe[i] * (x % a[i]) % M for i in range(n)) % M
assert check == x, "CRT coefficients do not reconstruct the test value"
print(check)
注意要先彻底模完 \(M\) 再去模 \(p\)。
// Raw literal operator: converts the decimal digit string of a literal such
// as 123_ubi into a 128-bit integer, so constants wider than 64 bits can be
// written directly in the source.
__int128 operator""_ubi(const char *str) {
  __int128 value = 0;
  for (const char *digit = str; *digit != '\0'; ++digit) {
    value = value * 10 + *digit - '0';
  }
  return value;
}
/**
 * Debug helper: writes the decimal representation of x followed by a line
 * break to stderr. Does nothing unless LOCAL is defined.
 *
 * Digits are produced one at a time, so values above 1e18 keep their interior
 * zeros and values outside the long long range are not truncated.
 */
void print(__int128 x) {
#ifdef LOCAL
  const bool negative = x < 0;
  // Negate in unsigned arithmetic so the most negative value is handled too.
  unsigned __int128 magnitude = static_cast<unsigned __int128>(x);
  if (negative) magnitude = -magnitude;
  string digits;
  do {
    digits.push_back(static_cast<char>('0' + static_cast<int>(magnitude % 10)));
    magnitude /= 10;
  } while (magnitude != 0);
  if (negative) digits.push_back('-');
  reverse(digits.begin(), digits.end());
  cerr << digits << endl;
#else
  (void)x;
#endif
}
/**
 * Multiplies polynomials a and b and reduces every coefficient modulo p.
 *
 * The product is computed modulo three NTT primes and the residues are merged
 * with the Chinese remainder theorem before the final reduction by p.
 * Returns a.size() + b.size() - 1 coefficients, or an empty vector when either
 * input is empty.
 * NOTE(review): assumes non-negative inputs whose exact product coefficients
 * stay below the product of the three primes.
 */
vector<int> convolution_fun(const vector<int> &a, const vector<int> &b, int p) {
  static constexpr unsigned mods[] = {998244353, 1004535809, 469762049};
  typedef modint<mods[0]> mint1;
  typedef modint<mods[1]> mint2;
  typedef modint<mods[2]> mint3;
  if (a.empty() || b.empty()) return {};
  // The _ubi literals are parsed at run time, so evaluate each one once here
  // instead of once per output coefficient.
  const __int128 M = 471064322751194440790966273_ubi;  // product of the three primes
  // CRT coefficients produced by the generator script shown in the article.
  const __int128 coe1 = 401276874248923522479908641_ubi;
  const __int128 coe2 = 185347017962817624218731910_ubi;
  const __int128 coe3 = 355504753290647734883291996_ubi;
  const int rlen = a.size() + b.size() - 1;
  vector<int> ret(rlen);
  const auto ret1 = convolution(vector<mint1>(a.begin(), a.end()),
                                vector<mint1>(b.begin(), b.end()));
  const auto ret2 = convolution(vector<mint2>(a.begin(), a.end()),
                                vector<mint2>(b.begin(), b.end()));
  const auto ret3 = convolution(vector<mint3>(a.begin(), a.end()),
                                vector<mint3>(b.begin(), b.end()));
  for (int i = 0; i < rlen; i++) {
    // Reduce fully modulo M first; only then is it valid to reduce modulo p.
    const __int128 merged = (coe1 * raw(ret1[i]) % M +
                             coe2 * raw(ret2[i]) % M +
                             coe3 * raw(ret3[i]) % M) % M;
    ret[i] = static_cast<int>(merged % p);
  }
  return ret;
}
// Reads the degrees n and m, the modulus p and both coefficient lists from
// stdin, then prints the coefficients of the product modulo p on one line.
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);
#endif
  int n, m, p;
  cin >> n >> m >> p;
  vector<int> a(n + 1);
  for (int &coef : a) cin >> coef;
  vector<int> b(m + 1);
  for (int &coef : b) cin >> coef;
  const vector<int> product = convolution_fun(a, b, p);
  for (size_t i = 0; i < product.size(); i++) cout << product[i] << " ";
  cout << endl;
  return 0;
}
dif-dif 实现
https://charleswu.site/archives/3065
#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
typedef long long LL;
// Constraint helper: names `int` when T is an integral type; otherwise the
// substitution fails and the constrained overload is discarded.
template <class T>
using must_int = typename enable_if<is_integral<T>::value, int>::type;

/**
 * Residue class modulo the compile-time constant `umod`.
 *
 * The stored representative `v` is kept in [0, umod). The addition and
 * subtraction routines rely on 2 * umod fitting in an unsigned 32-bit value.
 */
template <unsigned umod>
struct modint {
  static constexpr int mod = umod;
  unsigned v;

  modint() : v(0) {}

  // Builds a residue from any integral value; negative inputs are mapped
  // into [0, mod).
  template <class T, must_int<T> = 0>
  modint(T x) {
    x %= mod;
    if (x < 0) {
      v = x + mod;
    } else {
      v = x;
    }
  }

  modint operator+() const { return *this; }

  modint operator-() const {
    modint negated;
    negated -= *this;
    return negated;
  }

  // Returns the stored representative as a plain int.
  friend int raw(const modint &self) { return static_cast<int>(self.v); }

  friend ostream &operator<<(ostream &os, const modint &self) {
    os << raw(self);
    return os;
  }

  modint &operator+=(const modint &rhs) {
    const unsigned sum = v + rhs.v;
    v = sum >= umod ? sum - umod : sum;
    return *this;
  }

  modint &operator-=(const modint &rhs) {
    // Unsigned wrap-around makes a "negative" difference appear as a value
    // >= umod, which is then shifted back into range.
    const unsigned diff = v - rhs.v;
    v = diff >= umod ? diff + umod : diff;
    return *this;
  }

  modint &operator*=(const modint &rhs) {
    const unsigned long long wide = static_cast<unsigned long long>(v) * rhs.v;
    v = wide % umod;
    return *this;
  }

  // Division multiplies by rhs^(mod - 2); asserts that rhs is nonzero.
  modint &operator/=(const modint &rhs) {
    assert(rhs.v);
    const modint inverse = qpow(rhs, mod - 2);
    return *this *= inverse;
  }

  // Binary exponentiation: a^b for an integral exponent b.
  template <class T, must_int<T> = 0>
  friend modint qpow(modint a, T b) {
    modint result = 1;
    while (b) {
      if (b & 1) result *= a;
      a *= a;
      b >>= 1;
    }
    return result;
  }

  friend modint operator+(modint lhs, const modint &rhs) {
    lhs += rhs;
    return lhs;
  }
  friend modint operator-(modint lhs, const modint &rhs) {
    lhs -= rhs;
    return lhs;
  }
  friend modint operator*(modint lhs, const modint &rhs) {
    lhs *= rhs;
    return lhs;
  }
  friend modint operator/(modint lhs, const modint &rhs) {
    lhs /= rhs;
    return lhs;
  }

  bool operator==(const modint &rhs) const { return v == rhs.v; }
  bool operator!=(const modint &rhs) const { return !(v == rhs.v); }
};
// Returns the smallest power of two that is >= n (1 for n <= 1).
int glim(int n) {
  // __builtin_clz(0) is undefined, so the n <= 1 case (n - 1 <= 0) must be
  // answered before the bit width of n - 1 is computed.
  if (n <= 1) return 1;
  return 1 << (32 - __builtin_clz(n - 1));
}
// Number of trailing zero bits of n. n must be nonzero: the underlying
// builtin is undefined for 0.
int bitctz(int n) {
  const int trailing_zeros = __builtin_ctz(n);
  return trailing_zeros;
}
/**
 * NTT context for the field of `mint` with generator G.
 *
 * Keeps a lazily grown twiddle table `w` and provides a forward pass (`dif`)
 * and an inverse pass (`dit`) that are meant to be used as a pair; neither
 * performs an explicit bit-reversal permutation.
 * NOTE(review): both passes appear to require a.size() to be a power of two.
 */
template <class mint, int G>
struct NTT_env {
  vector<mint> w{1};

  // Grows the twiddle table by doubling until it has at least n entries.
  void init(int n) {
    while (w.size() < n) {
      const int have = w.size();
      const mint root = qpow(mint(G), ((mint::mod - 1) / have) >> 2);
      w.resize(have << 1);
      // The new upper half is the existing table scaled by `root`.
      for (int j = 0; j < have; j++) w[have + j] = root * w[j];
    }
  }

  // Forward transform, in place: block size halves on every round.
  void dif(vector<mint> &a) {
    const int n = a.size();
    init(n);
    int span = n;
    for (int half = n >> 1; half >= 1; half >>= 1) {
      int block = 0;
      for (int base = 0; base < n; base += span, block++) {
        const mint twiddle = w[block];
        for (int j = base; j < base + half; j++) {
          const mint lo = a[j];
          const mint hi = a[j + half] * twiddle;
          a[j] = lo + hi;
          a[j + half] = lo - hi;
        }
      }
      span = half;
    }
  }

  // Inverse transform, in place: block size doubles on every round, then the
  // result is scaled by 1/n and a[1..] is reversed.
  void dit(vector<mint> &a) {
    const int n = a.size();
    init(n);
    for (int half = 1; (half << 1) <= n; half <<= 1) {
      const int span = half << 1;
      int block = 0;
      for (int base = 0; base < n; base += span, block++) {
        const mint twiddle = w[block];
        for (int j = base; j < base + half; j++) {
          const mint lo = a[j];
          const mint hi = a[j + half];
          a[j] = lo + hi;
          a[j + half] = (lo - hi) * twiddle;
        }
      }
    }
    const mint inv_n = mint(1) / n;
    for (mint &coef : a) coef *= inv_n;
    reverse(a.begin() + 1, a.end());
  }
};
/**
 * Multiplies two coefficient vectors over the field of `mint` using the
 * dif/dit transform pair of NTT_env (generator 3).
 *
 * Returns a vector of a.size() + b.size() - 1 coefficients, or an empty
 * vector when either operand is empty.
 */
template <class mint>
vector<mint> convolution(vector<mint> a, vector<mint> b) {
  // Without this guard rlen would be <= 0 for two empty operands, and a lone
  // empty operand would yield a non-empty all-zero result.
  if (a.empty() || b.empty()) return {};
  const int rlen = a.size() + b.size() - 1;
  // Product of two constants: no transform needed, and it avoids asking
  // glim for the transform size of a length-1 result.
  if (rlen == 1) return {a[0] * b[0]};
  static NTT_env<mint, 3> ntt;
  const int len = glim(rlen);
  a.resize(len), ntt.dif(a);
  b.resize(len), ntt.dif(b);
  for (int i = 0; i < len; i++) a[i] *= b[i];
  ntt.dit(a), a.resize(rlen);
  return a;
}
// Raw literal operator: converts the decimal digit string of a literal such
// as 123_ubi into a 128-bit integer, so constants wider than 64 bits can be
// written directly in the source.
__int128 operator""_ubi(const char *str) {
  __int128 value = 0;
  for (const char *digit = str; *digit != '\0'; ++digit) {
    value = value * 10 + *digit - '0';
  }
  return value;
}
/**
 * Debug helper: writes the decimal representation of x followed by a line
 * break to stderr. Does nothing unless LOCAL is defined.
 *
 * Digits are produced one at a time, so values above 1e18 keep their interior
 * zeros and values outside the long long range are not truncated.
 */
void print(__int128 x) {
#ifdef LOCAL
  const bool negative = x < 0;
  // Negate in unsigned arithmetic so the most negative value is handled too.
  unsigned __int128 magnitude = static_cast<unsigned __int128>(x);
  if (negative) magnitude = -magnitude;
  string digits;
  do {
    digits.push_back(static_cast<char>('0' + static_cast<int>(magnitude % 10)));
    magnitude /= 10;
  } while (magnitude != 0);
  if (negative) digits.push_back('-');
  reverse(digits.begin(), digits.end());
  cerr << digits << endl;
#else
  (void)x;
#endif
}
/**
 * Multiplies polynomials a and b and reduces every coefficient modulo p.
 *
 * The product is computed modulo three NTT primes and the residues are merged
 * with the Chinese remainder theorem before the final reduction by p.
 * Returns a.size() + b.size() - 1 coefficients, or an empty vector when either
 * input is empty.
 * NOTE(review): assumes non-negative inputs whose exact product coefficients
 * stay below the product of the three primes.
 */
vector<int> convolution_fun(const vector<int> &a, const vector<int> &b, int p) {
  static constexpr unsigned mods[] = {5 << 25 | 1, 7 << 26 | 1, 119 << 23 | 1};
  typedef modint<mods[0]> mint1;
  typedef modint<mods[1]> mint2;
  typedef modint<mods[2]> mint3;
  if (a.empty() || b.empty()) return {};
  // The _ubi literals are parsed at run time, so evaluate each one once here
  // instead of once per output coefficient.
  const __int128 M = 78674626319836206717730817_ubi;
  // NOTE(review): presumably M is the product of the three primes and coeK is
  // the CRT coefficient for mods[K-1] — confirm against a generator run.
  const __int128 coe1 = 13862981345800242111649459_ubi;
  const __int128 coe2 = 19425833427644834021761124_ubi;
  const __int128 coe3 = 45385811546391130584320235_ubi;
  const int rlen = a.size() + b.size() - 1;
  vector<int> ret(rlen);
  const auto ret1 = convolution(vector<mint1>(a.begin(), a.end()),
                                vector<mint1>(b.begin(), b.end()));
  const auto ret2 = convolution(vector<mint2>(a.begin(), a.end()),
                                vector<mint2>(b.begin(), b.end()));
  const auto ret3 = convolution(vector<mint3>(a.begin(), a.end()),
                                vector<mint3>(b.begin(), b.end()));
  for (int i = 0; i < rlen; i++) {
    // Reduce fully modulo M first; only then is it valid to reduce modulo p.
    const __int128 merged = (coe1 * raw(ret1[i]) +
                             coe2 * raw(ret2[i]) +
                             coe3 * raw(ret3[i])) % M;
    ret[i] = static_cast<int>(merged % p);
  }
  return ret;
}
// Reads the degrees n and m, the modulus p and both coefficient lists from
// stdin, then prints the coefficients of the product modulo p on one line.
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);
#endif
  int n, m, p;
  cin >> n >> m >> p;
  vector<int> a(n + 1);
  for (int &coef : a) cin >> coef;
  vector<int> b(m + 1);
  for (int &coef : b) cin >> coef;
  const vector<int> product = convolution_fun(a, b, p);
  for (size_t i = 0; i < product.size(); i++) cout << product[i] << " ";
  cout << endl;
  return 0;
}
本文来自博客园,作者:caijianhong,转载请注明原文链接:https://www.cnblogs.com/caijianhong/p/18127492