【模板】任意模数多项式乘法:三模 NTT

前置知识

https://www.cnblogs.com/caijianhong/p/template-crt.html

https://www.cnblogs.com/caijianhong/p/template-fft.html

题目描述

任意模数多项式乘法

solution

首先我们打开 https://blog.miskcoo.com/2014/07/fft-prime-table 这篇文章找到 \(998244353\) 附近的几个质数:

  • \(167772161 = 5 \times 2^{25} + 1\) 的原根为 \(g = 3\)
  • \(469762049 = 7 \times 2^{26} + 1\) 的原根为 \(g = 3\)
  • \(998244353 = 119 \times 2^{23} + 1\) 的原根为 \(g = 3\)
  • \(1004535809 = 479 \times 2^{21} + 1\) 的原根为 \(g = 3\)
  • \(2013265921 = 15 \times 2^{27} + 1\) 的原根为 \(g = 31\)

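因为最终每一项系数的真实值不超过 \(p^2(\min(n, m) + 1)\),大约为 \(10^{24}\),所以可以选择三个质数,使它们的乘积 \(>10^{24}\)。这样 CRT 合并出的结果就是真实值本身。比如说可以选 \(998244353\) 与它前面的两个质数(\(167772161\)、\(469762049\),乘积约 \(7.9 \times 10^{25}\),第二份代码用的就是这一组),刚好原根都是 \(3\)。第一份代码选的是 \(998244353\)、\(1004535809\)、\(469762049\)(乘积约 \(4.7 \times 10^{26}\)),原根同样都是 \(3\)。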

然后以三个质数为模数分别跑三次 NTT 求卷积。

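最后用中国剩余定理(CRT)合并三个模数下的结果。记 \(M = p_1 p_2 p_3\),合并得到的是真实系数模 \(M\) 的值。由于真实系数小于 \(M\),这个值就是真实系数本身,再对 \(p\) 取模即可。\(M\) 超出了 64 位整数的范围,所以合并时需要用 __int128。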

实现

注意事项:

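  1. std::is_integral<__int128>::value 在 -std=c++14 下为 false(在 -std=gnu++14 等 GNU 方言下为 true),所以不能指望 modint 的整数构造函数(依赖 is_integral)接受 __int128。
  2. CF 上只有 64 位的 C++20 编译器支持 __int128(32 位编译器没有这个类型)。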
点击查看代码

#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
typedef long long LL;
// SFINAE guard: names the type `int` only when T is integral, so templates
// that take a `must_int<T> = 0` parameter accept integers only.
template <class T>
using must_int = typename enable_if<is_integral<T>::value, int>::type;
// Residue class modulo the compile-time constant `umod`.
// Invariant: 0 <= v < umod after every operation.
template <unsigned umod>
struct modint {
  static constexpr int mod = umod;
  unsigned v;
  modint() : v(0) {}
  // Accepts any integral value, including negative ones, and reduces it.
  template <class T, must_int<T> = 0>
  modint(T x) {
    x %= mod;
    if (x < 0)
      v = x + mod;
    else
      v = x;
  }
  modint operator+() const { return *this; }
  modint operator-() const {
    modint zero;
    zero -= *this;
    return zero;
  }
  friend int raw(const modint &self) { return static_cast<int>(self.v); }
  friend ostream &operator<<(ostream &os, const modint &self) {
    os << raw(self);
    return os;
  }
  modint &operator+=(const modint &rhs) {
    const unsigned sum = v + rhs.v;
    v = sum >= umod ? sum - umod : sum;
    return *this;
  }
  modint &operator-=(const modint &rhs) {
    // If rhs.v > v the unsigned difference wraps to a value >= umod.
    const unsigned diff = v - rhs.v;
    v = diff >= umod ? diff + umod : diff;
    return *this;
  }
  modint &operator*=(const modint &rhs) {
    const unsigned long long prod = static_cast<unsigned long long>(v) * rhs.v;
    v = static_cast<unsigned>(prod % umod);
    return *this;
  }
  // Division through Fermat's little theorem; only valid when umod is prime.
  modint &operator/=(const modint &rhs) {
    assert(rhs.v);
    *this *= qpow(rhs, mod - 2);
    return *this;
  }
  // Binary exponentiation a^b for a non-negative integral exponent b.
  template <class T, must_int<T> = 0>
  friend modint qpow(modint a, T b) {
    modint r = 1;
    while (b) {
      if (b & 1) r *= a;
      a *= a;
      b >>= 1;
    }
    return r;
  }
  friend modint operator+(modint lhs, const modint &rhs) {
    lhs += rhs;
    return lhs;
  }
  friend modint operator-(modint lhs, const modint &rhs) {
    lhs -= rhs;
    return lhs;
  }
  friend modint operator*(modint lhs, const modint &rhs) {
    lhs *= rhs;
    return lhs;
  }
  friend modint operator/(modint lhs, const modint &rhs) {
    lhs /= rhs;
    return lhs;
  }
  bool operator==(const modint &rhs) const { return v == rhs.v; }
  bool operator!=(const modint &rhs) const { return !(*this == rhs); }
};
// Smallest power of two that is >= n (NTT length for n coefficients).
// n <= 1 maps to 1; the guard matters because __builtin_clz(0) is undefined.
int glim(int n) { return n <= 1 ? 1 : 1 << (32 - __builtin_clz(n - 1)); }
// Number of trailing zero bits of n; n must be non-zero (__builtin_ctz(0) is undefined).
int bitctz(int n) { return __builtin_ctz(n); }
// In-place iterative radix-2 NTT over Z / mint::mod.
// op == 1: forward transform; op == -1: inverse transform, including the 1/n
// scaling. a.size() must be a power of two whose exponent does not exceed the
// number of factors of 2 in mint::mod - 1 (checked by assert).
// NOTE: the primitive root 3 is hard-coded. It is correct for the moduli this
// file instantiates, but not for e.g. 2013265921, whose primitive root is 31.
template <class mint>
void ntt(vector<mint> &a, int op) {
  // wns[k] = 3^((mod - 1) / 2^(k + 1)): a primitive 2^(k+1)-th root of unity
  // when 3 is a primitive root of mod. Built once per modulus.
  static vector<mint> wns;
  if (wns.empty()) {
    int exp = mint::mod - 1;
    while (exp % 2 == 0) wns.push_back(qpow(mint(3), exp >>= 1));
  }
  int n = a.size();
  assert(n > 0 && (n & (n - 1)) == 0);
  // The last stage reads wns[ctz(n) - 1]; longer transforms are unsupported.
  assert(bitctz(n) <= (int)wns.size());
  // Bit-reversal permutation. Going from i - 1 to i flips bits 0..ctz(i) of i,
  // i.e. the top ctz(i) + 1 bits of the reversed index r.
  for (int i = 1, r = 0; i < n; i++) {
    r ^= n - (1 << (bitctz(n) - bitctz(i) - 1));
    if (r > i) swap(a[r], a[i]);
  }
  vector<mint> w(n);
  // k is half the butterfly span; wns[ctz(k)] has order 2k == len.
  for (int k = 1, len = 2; len <= n; len <<= 1, k <<= 1) {
    mint wn = wns[bitctz(k)];
    w[0] = 1;
    for (int i = 1; i < k; i++) w[i] = w[i - 1] * wn;
    for (int i = 0; i < n; i += len) {
      for (int j = 0; j < k; j++) {
        mint x = a[i + j], y = a[i + j + k] * w[j];
        a[i + j] = x + y;
        a[i + j + k] = x - y;
      }
    }
  }
  if (op == -1) {
    // Evaluating at w^-t equals evaluating at w^(n-t): reverse a[1..n-1],
    // then scale by 1/n.
    reverse(a.begin() + 1, a.end());
    mint iz = mint(1) / n;
    for (int i = 0; i < n; i++) a[i] *= iz;
  }
}
// Linear convolution of a and b modulo mint::mod, computed with ntt().
// Returns a + b coefficients of size a.size() + b.size() - 1, or an empty
// vector if either input is empty (the size expression would underflow).
template <class mint>
vector<mint> convolution(vector<mint> a, vector<mint> b) {
  if (a.empty() || b.empty()) return {};
  int rlen = a.size() + b.size() - 1, len = glim(rlen);
  a.resize(len), ntt(a, 1);
  b.resize(len), ntt(b, 1);
  for (int i = 0; i < len; i++) a[i] *= b[i];
  ntt(a, -1), a.resize(rlen);
  return a;
}

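主要的部分在这里,实现比较逆天,仅供参考。其中那些巨大的数字是模数乘积 \(M\) 以及 CRT 系数 \(\frac{M}{p_i} \cdot \left(\left(\frac{M}{p_i}\right)^{-1} \bmod p_i\right) \bmod M\),由下面的 Python 程序生成: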

#!/usr/bin/env python3
# Generate the CRT constants hard-coded in convolution_fun.
# For pairwise-coprime moduli a[i] with product M, m_i = M / a[i] and
# coe[i] = m_i * (m_i^-1 mod a[i]) mod M, so that sum(coe[i] * r[i]) mod M is
# the unique x in [0, M) with x = r[i] (mod a[i]) for every i.
import math

a = [998244353, 1004535809, 469762049]
n = len(a)
M = math.prod(a)
m = [M // a[i] for i in range(n)]
t = [pow(m[i], -1, a[i]) for i in range(n)]  # modular inverse (Python 3.8+)
coe = [m[i] * t[i] % M for i in range(n)]
typ = "unsigned __int128"
prog = f"""
{typ} operator\"\"_ubi(const char* str) {{
  int len = strlen(str);
  {typ} x = 0;
  for (int i = 0; i < len; i++) x = x * 10 + str[i] - '0';
  return x;
}}
"""
print(prog)
print(*[f"({coe[i]}_ubi * raw(ret{i + 1}[i]) % {M}_ubi)" for i in range(n)], sep=" + ")

# Self-check: a value congruent to x modulo every a[i] must reconstruct to x.
# (Computed from coe directly instead of re-typing the constants, and without
# clobbering the modulus list `a`.)
x = 1463
check = sum(c * x % M for c in coe) % M
print(check)
assert check == x

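注意要先彻底模完 \(M\) 再去模 \(p\)。CRT 只保证结果在模 \(M\) 意义下正确,只有落在 \([0, M)\) 中的那个代表元才是真实系数。而 \(M\) 一般不是 \(p\) 的倍数,多出来的 \(M\) 的倍数会让模 \(p\) 的结果出错。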

// Raw literal operator: `123_ubi` parses the decimal digits of the literal
// into an __int128, so constants wider than 64 bits can be written in source.
// Digit separators (') are skipped. Only decimal literals are supported; any
// other character (e.g. from a hex literal) trips the assert instead of
// silently producing a wrong value.
__int128 operator""_ubi(const char *str) {
  __int128 x = 0;
  for (; *str; ++str) {
    if (*str == '\'') continue;
    assert('0' <= *str && *str <= '9');
    x = x * 10 + (*str - '0');
  }
  return x;
}
// Debug helper: writes the decimal value of x to stderr in LOCAL builds and
// is a no-op otherwise. iostream has no operator<< for __int128, so the digits
// are produced by repeated division. This handles any magnitude and sign.
// The previous high/low split dropped the zero padding of the low 18 digits.
void print(__int128 x) {
#ifdef LOCAL
  // Negating in the unsigned type is well-defined, even for the minimum value.
  unsigned __int128 u = x < 0 ? -static_cast<unsigned __int128>(x)
                              : static_cast<unsigned __int128>(x);
  string digits;
  do {
    digits += static_cast<char>('0' + static_cast<int>(u % 10));
    u /= 10;
  } while (u != 0);
  if (x < 0) digits += '-';
  reverse(digits.begin(), digits.end());
  cerr << digits << endl;
#else
  (void)x;  // unused outside LOCAL builds
#endif
}
// Multiplies polynomials a and b and reduces every coefficient modulo p.
// The product is computed with three NTTs modulo 998244353, 1004535809 and
// 469762049. Each coefficient is merged by CRT into its value in [0, M),
// M = 471064322751194440790966273 (the product of the moduli), then reduced
// mod p. The result is exact only while every true coefficient is < M.
// Returns an empty vector if either input is empty.
vector<int> convolution_fun(const vector<int> &a, const vector<int> &b, int p) {
  static constexpr unsigned mods[] = {998244353, 1004535809, 469762049};
  typedef modint<mods[0]> mint1;
  typedef modint<mods[1]> mint2;
  typedef modint<mods[2]> mint3;
  if (a.empty() || b.empty()) return {};
  // CRT coefficients (see the Python generator). Parsed once per call instead
  // of once per coefficient inside the loop.
  const __int128 M = 471064322751194440790966273_ubi;
  const __int128 c1 = 401276874248923522479908641_ubi;
  const __int128 c2 = 185347017962817624218731910_ubi;
  const __int128 c3 = 355504753290647734883291996_ubi;
  int rlen = a.size() + b.size() - 1;
  vector<int> ret(rlen);
  auto ret1 = convolution(vector<mint1>(a.begin(), a.end()),
                          vector<mint1>(b.begin(), b.end()));
  auto ret2 = convolution(vector<mint2>(a.begin(), a.end()),
                          vector<mint2>(b.begin(), b.end()));
  auto ret3 = convolution(vector<mint3>(a.begin(), a.end()),
                          vector<mint3>(b.begin(), b.end()));
  for (int i = 0; i < rlen; i++) {
    // Each product is below M * 2^31 < 2^127, so it fits in signed __int128.
    // Reduce fully modulo M before reducing modulo p.
    __int128 x = (c1 * raw(ret1[i]) % M + c2 * raw(ret2[i]) % M +
                  c3 * raw(ret3[i]) % M) %
                 M;
    ret[i] = x % p;
  }
  return ret;
}
// Reads n, m, p, then the n + 1 coefficients of A and the m + 1 coefficients
// of B (index = degree), and prints the coefficients of A * B modulo p, each
// followed by a space.
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);
#endif
  int n, m, p;
  cin >> n >> m >> p;
  vector<int> a(n + 1);
  for (int &coef : a) cin >> coef;
  vector<int> b(m + 1);
  for (int &coef : b) cin >> coef;
  const vector<int> prod = convolution_fun(a, b, p);
  for (size_t i = 0; i != prod.size(); ++i) cout << prod[i] << " ";
  cout << endl;
  return 0;
}

dif-dit 实现(正变换用 DIF,逆变换用 DIT,省去位逆序置换)

https://charleswu.site/archives/3065


#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
typedef long long LL;
// SFINAE guard: names the type `int` only when T is integral, so templates
// that take a `must_int<T> = 0` parameter accept integers only.
template <class T>
using must_int = typename enable_if<is_integral<T>::value, int>::type;
// Residue class modulo the compile-time constant `umod`.
// Invariant: 0 <= v < umod after every operation.
template <unsigned umod>
struct modint {
  static constexpr int mod = umod;
  unsigned v;
  modint() : v(0) {}
  // Accepts any integral value, including negative ones, and reduces it.
  template <class T, must_int<T> = 0>
  modint(T x) {
    x %= mod;
    if (x < 0)
      v = x + mod;
    else
      v = x;
  }
  modint operator+() const { return *this; }
  modint operator-() const {
    modint zero;
    zero -= *this;
    return zero;
  }
  friend int raw(const modint &self) { return static_cast<int>(self.v); }
  friend ostream &operator<<(ostream &os, const modint &self) {
    os << raw(self);
    return os;
  }
  modint &operator+=(const modint &rhs) {
    const unsigned sum = v + rhs.v;
    v = sum >= umod ? sum - umod : sum;
    return *this;
  }
  modint &operator-=(const modint &rhs) {
    // If rhs.v > v the unsigned difference wraps to a value >= umod.
    const unsigned diff = v - rhs.v;
    v = diff >= umod ? diff + umod : diff;
    return *this;
  }
  modint &operator*=(const modint &rhs) {
    const unsigned long long prod = static_cast<unsigned long long>(v) * rhs.v;
    v = static_cast<unsigned>(prod % umod);
    return *this;
  }
  // Division through Fermat's little theorem; only valid when umod is prime.
  modint &operator/=(const modint &rhs) {
    assert(rhs.v);
    *this *= qpow(rhs, mod - 2);
    return *this;
  }
  // Binary exponentiation a^b for a non-negative integral exponent b.
  template <class T, must_int<T> = 0>
  friend modint qpow(modint a, T b) {
    modint r = 1;
    while (b) {
      if (b & 1) r *= a;
      a *= a;
      b >>= 1;
    }
    return r;
  }
  friend modint operator+(modint lhs, const modint &rhs) {
    lhs += rhs;
    return lhs;
  }
  friend modint operator-(modint lhs, const modint &rhs) {
    lhs -= rhs;
    return lhs;
  }
  friend modint operator*(modint lhs, const modint &rhs) {
    lhs *= rhs;
    return lhs;
  }
  friend modint operator/(modint lhs, const modint &rhs) {
    lhs /= rhs;
    return lhs;
  }
  bool operator==(const modint &rhs) const { return v == rhs.v; }
  bool operator!=(const modint &rhs) const { return !(*this == rhs); }
};
// Smallest power of two that is >= n (NTT length for n coefficients).
// n <= 1 maps to 1; the guard matters because __builtin_clz(0) is undefined.
int glim(int n) { return n <= 1 ? 1 : 1 << (32 - __builtin_clz(n - 1)); }
// Number of trailing zero bits of n; n must be non-zero (__builtin_ctz(0) is undefined).
int bitctz(int n) { return __builtin_ctz(n); }
// NTT without an explicit bit-reversal pass: dif() is the forward transform
// and dit() its inverse (scheme from https://charleswu.site/archives/3065).
// G must be a primitive root of mint::mod.
// `w` caches twiddle factors. Entry w[t] depends only on t (each doubling step
// from size m uses a root that depends only on m), so the cache only grows.
template <class mint, int G>
struct NTT_env {
  vector<mint> w{1};
  // Grow w to at least n entries. Doubling from size m needs a root of order
  // 4m, which exists only when 4m divides mod - 1.
  void init(int n) {
    while ((int)w.size() < n) {
      int m = w.size();
      assert((mint::mod - 1) % (m << 2) == 0);
      mint wn = qpow(mint(G), (mint::mod - 1) / m >> 2);
      w.resize(m << 1);
      for (int i = m; i < m << 1; i++) w[i] = wn * w[i ^ m];
    }
  }
  // Forward transform in place; a.size() must be a power of two. The output
  // is left in the transform's internal order (no bit-reversal pass);
  // convolution() only multiplies two such outputs pointwise before dit().
  void dif(vector<mint> &a) {
    int n = a.size();
    // Butterflies read only w[0 .. n/2 - 1]. Building just n/2 entries keeps
    // the largest transform length at the full power of two dividing mod - 1.
    init(n >> 1);
    for (int len = n, k = n >> 1; k >= 1; len = k, k >>= 1) {
      for (int i = 0, t = 0; i < n; i += len, t++) {
        for (int j = 0; j < k; j++) {
          mint x = a[i + j], y = a[i + j + k] * w[t];
          a[i + j] = x + y, a[i + j + k] = x - y;
        }
      }
    }
  }
  // Inverse of dif() in place: butterflies in reverse stage order, scaling by
  // 1/n, and reversing a[1..n-1] to restore natural coefficient order.
  void dit(vector<mint> &a) {
    int n = a.size();
    init(n >> 1);
    for (int k = 1, len = 2; len <= n; k = len, len <<= 1) {
      for (int i = 0, t = 0; i < n; i += len, t++) {
        for (int j = 0; j < k; j++) {
          mint x = a[i + j], y = a[i + j + k];
          a[i + j] = x + y, a[i + j + k] = (x - y) * w[t];
        }
      }
    }
    mint iv = mint(1) / n;
    for (int i = 0; i < n; i++) a[i] *= iv;
    reverse(a.begin() + 1, a.end());
  }
};
// Linear convolution of a and b modulo mint::mod, using a DIF/DIT NTT pair
// whose twiddle cache is shared per mint type.
// Returns a vector of size a.size() + b.size() - 1, or an empty vector if
// either input is empty (the size expression would underflow).
template <class mint>
vector<mint> convolution(vector<mint> a, vector<mint> b) {
  static NTT_env<mint, 3> ntt;
  if (a.empty() || b.empty()) return {};
  int rlen = a.size() + b.size() - 1, len = glim(rlen);
  a.resize(len), ntt.dif(a);
  b.resize(len), ntt.dif(b);
  for (int i = 0; i < len; i++) a[i] *= b[i];
  ntt.dit(a), a.resize(rlen);
  return a;
}
// Raw literal operator: `123_ubi` parses the decimal digits of the literal
// into an __int128, so constants wider than 64 bits can be written in source.
// Digit separators (') are skipped. Only decimal literals are supported; any
// other character (e.g. from a hex literal) trips the assert instead of
// silently producing a wrong value.
__int128 operator""_ubi(const char *str) {
  __int128 x = 0;
  for (; *str; ++str) {
    if (*str == '\'') continue;
    assert('0' <= *str && *str <= '9');
    x = x * 10 + (*str - '0');
  }
  return x;
}
// Debug helper: writes the decimal value of x to stderr in LOCAL builds and
// is a no-op otherwise. iostream has no operator<< for __int128, so the digits
// are produced by repeated division. This handles any magnitude and sign.
// The previous high/low split dropped the zero padding of the low 18 digits.
void print(__int128 x) {
#ifdef LOCAL
  // Negating in the unsigned type is well-defined, even for the minimum value.
  unsigned __int128 u = x < 0 ? -static_cast<unsigned __int128>(x)
                              : static_cast<unsigned __int128>(x);
  string digits;
  do {
    digits += static_cast<char>('0' + static_cast<int>(u % 10));
    u /= 10;
  } while (u != 0);
  if (x < 0) digits += '-';
  reverse(digits.begin(), digits.end());
  cerr << digits << endl;
#else
  (void)x;  // unused outside LOCAL builds
#endif
}
// Multiplies polynomials a and b and reduces every coefficient modulo p.
// The product is computed with three NTTs modulo 167772161, 469762049 and
// 998244353. Each coefficient is merged by CRT into its value in [0, M),
// M = 78674626319836206717730817 (the product of the moduli), then reduced
// mod p. The result is exact only while every true coefficient is < M.
// Returns an empty vector if either input is empty.
vector<int> convolution_fun(const vector<int> &a, const vector<int> &b, int p) {
  static constexpr unsigned mods[] = {5 << 25 | 1, 7 << 26 | 1, 119 << 23 | 1};
  typedef modint<mods[0]> mint1;
  typedef modint<mods[1]> mint2;
  typedef modint<mods[2]> mint3;
  if (a.empty() || b.empty()) return {};
  // CRT coefficients for the moduli above. Parsed once per call instead of
  // once per coefficient inside the loop.
  const __int128 M = 78674626319836206717730817_ubi;
  const __int128 c1 = 13862981345800242111649459_ubi;
  const __int128 c2 = 19425833427644834021761124_ubi;
  const __int128 c3 = 45385811546391130584320235_ubi;
  int rlen = a.size() + b.size() - 1;
  vector<int> ret(rlen);
  auto ret1 = convolution(vector<mint1>(a.begin(), a.end()),
                          vector<mint1>(b.begin(), b.end()));
  auto ret2 = convolution(vector<mint2>(a.begin(), a.end()),
                          vector<mint2>(b.begin(), b.end()));
  auto ret3 = convolution(vector<mint3>(a.begin(), a.end()),
                          vector<mint3>(b.begin(), b.end()));
  for (int i = 0; i < rlen; i++) {
    // Each term is below M * 2^30, so the sum of three fits in signed
    // __int128. Reduce fully modulo M before reducing modulo p.
    __int128 x = (c1 * raw(ret1[i]) + c2 * raw(ret2[i]) + c3 * raw(ret3[i])) % M;
    ret[i] = x % p;
  }
  return ret;
}
// Reads n, m, p, then the n + 1 coefficients of A and the m + 1 coefficients
// of B (index = degree), and prints the coefficients of A * B modulo p, each
// followed by a space.
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);
#endif
  int n, m, p;
  cin >> n >> m >> p;
  vector<int> a(n + 1);
  for (int &coef : a) cin >> coef;
  vector<int> b(m + 1);
  for (int &coef : b) cin >> coef;
  const vector<int> prod = convolution_fun(a, b, p);
  for (size_t i = 0; i != prod.size(); ++i) cout << prod[i] << " ";
  cout << endl;
  return 0;
}

posted @ 2024-04-10 21:34  caijianhong  阅读(70)  评论(0编辑  收藏  举报