NTT 学习笔记
NTT
前置知识:FFT
NTT,中文“快速数论变换”,是 FFT 在数论领域上的实现,比 FFT 更快,应用更广。
对于 FFT,因为其涉及到复数操作,对于某些需要取模的题不再适用。并且因为需要求正弦与余弦,使用时难以避免精度误差。这时就需要用到 NTT 来解决问题了。
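我们知道 FFT 的实现是在复平面上找到了 $n$ 个 $n$ 次单位根 $\omega_n^0, \omega_n^1, \dots, \omega_n^{n-1}$,并利用它们的性质(尤其是折半引理)进行分治。模意义下无法使用复数,因此需要找到一组在模 $p$ 意义下具有相同性质的数来代替单位根。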
原根
我们可以使用原根找到这 $n$ 个在模意义下与单位根性质相同的数。
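假设要取模的奇素数为 $p$,$g$ 是 $p$ 的一个原根,即 $g$ 模 $p$ 的阶为 $p-1$:$g^0, g^1, \dots, g^{p-2}$ 在模 $p$ 意义下两两不同,且 $g^{p-1} \equiv 1 \pmod p$。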
设 $n \mid p-1$,记 $g_n = g^{\frac{p-1}{n}} \bmod p$。
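为了验证原根与单位根的相似性质,不妨令 $g_n^k = g^{\frac{k(p-1)}{n}}$ 对应 $\omega_n^k$,逐条检验单位根的性质:

- 乘法:$g_n^a \cdot g_n^b = g_n^{a+b}$。对于 $g_n^k = g^{\frac{k(p-1)}{n}}$ 显然成立。
- 周期性:$g_n^{k+n} = g_n^k$。由于 $g_n^n = g^{p-1} \equiv 1 \pmod p$(费马小定理),所以 $g_n^{k+n} = g_n^k \cdot g_n^n \equiv g_n^k$,成立。
- 互异性:$g_n^0, g_n^1, \dots, g_n^{n-1}$ 互不相同。因为指数 $\frac{k(p-1)}{n}$($0 \le k < n$)互不相同且都小于 $p-1$,根据原根的性质可以很容易得到它们不同。
- 消去引理:$\forall d > 0$(且 $dn \mid p-1$),有 $g_{dn}^{dk} = g_n^k$。证明:$g_{dn}^{dk} = g^{\frac{dk(p-1)}{dn}} = g^{\frac{k(p-1)}{n}} = g_n^k$。
- 折半引理:若 $n > 0$ 且 $n$ 为偶数,有 $(g_n^{k+n/2})^2 = (g_n^k)^2 = g_{n/2}^k$。首先 $(g_n^{k+n/2})^2 = g_n^{2k+n} = g_n^{2k}$,因为有周期性 $g_n^{n} \equiv 1$;再由消去引理得到 $g_n^{2k} = g_{n/2}^k$。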
可以发现,因为 $g_n^k$ 满足 FFT 所用到的全部单位根性质,所以只要把 FFT 中的 $\omega_n$ 替换为 $g_n$,并把所有运算改为在模 $p$ 意义下进行,就得到了 NTT。
模数
不是所有模数都可以使用 NTT(至少不能用朴素的 NTT)。
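在 FFT 中,为了方便使用折半引理,我们强制将 $n$ 补齐为 $2$ 的幂。NTT 同样如此,而 $g_n = g^{\frac{p-1}{n}}$ 要求 $n \mid p-1$,所以模数必须形如 $p = a \cdot 2^k + 1$,且 $2^k$ 不小于补齐后的长度。例如代码中使用的 $998244353 = 119 \cdot 2^{23} + 1$ 与 $167772161 = 5 \cdot 2^{25} + 1$,它们的原根都是 $3$。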
INTT
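将 FFT 逆变换中的 $\omega_n^{-1}$ 替换为 $g_n^{-1}$(即用 $g^{-1} \bmod p$ 代替 $g$ 再做一次正变换),最后把每个结果乘以 $n^{-1} \bmod p$,即得到逆变换 INTT。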
代码:
/// Binary exponentiation: returns a^b using T's `*=`.
/// T must be constructible from the literal 1; b is expected to be >= 0
/// (b == 0 yields T(1)).
template<class T>
T power(T a, i64 b) {
    T result = 1;
    while (b != 0) {
        if (b % 2 != 0)
            result *= a;
        a *= a;
        b >>= 1;
    }
    return result;
}
/// Integer modulo P, or modulo the runtime-settable `Mod` when P == 0.
/// Invariant maintained by every operation: 0 <= x < getMod().
/// Products are formed in i64, so getMod() must stay below 2^31 to avoid overflow.
template<const i64 P>
class ModInt {
public:
    i64 x;
    // Runtime modulus; only read by getMod() when P == 0.
    static i64 Mod;
    ModInt() : x{0} {}
    // Both constructors reduce arbitrary (including negative) values into [0, getMod()).
    ModInt(int _x) : x{(_x % getMod() + getMod()) % getMod()} {}
    ModInt(i64 _x) : x{(_x % getMod() + getMod()) % getMod()} {}
    static void setMod(i64 Mod_) {
        Mod = Mod_;
    }
    static i64 getMod() {
        return !P ? Mod : P;
    }
    explicit constexpr operator int() const {
        return x;
    }
    ModInt &operator += (ModInt a) & {
        x = x + a.x >= getMod() ? x + a.x - getMod() : x + a.x;
        return (*this);
    }
    ModInt &operator -= (ModInt a) & {
        x = x - a.x < 0 ? x - a.x + getMod() : x - a.x;
        return (*this);
    }
    ModInt &operator *= (ModInt a) & {
        (x *= a.x) %= getMod();
        return (*this);
    }
    // Inverse via Fermat's little theorem (x^(mod-2)); only valid for a prime modulus and x != 0.
    constexpr ModInt inv() {
        return power((*this), getMod() - 2);
    }
    ModInt &operator /= (ModInt a) & {
        return (*this) *= a.inv();
    }
    friend ModInt operator + (ModInt lhs, ModInt rhs) {
        return lhs += rhs;
    }
    friend ModInt operator - (ModInt lhs, ModInt rhs) {
        return lhs -= rhs;
    }
    friend ModInt operator * (ModInt lhs, ModInt rhs) {
        return lhs *= rhs;
    }
    friend ModInt operator / (ModInt lhs, ModInt rhs) {
        return lhs /= rhs;
    }
    friend std::istream &operator >> (std::istream &is, ModInt &p) {
        // Read into a temporary and reduce it, so negative or out-of-range input
        // cannot break the 0 <= x < getMod() invariant.
        i64 v = 0;
        is >> v;
        p = ModInt(v);
        return is;
    }
    friend std::ostream &operator << (std::ostream &os, ModInt p) {
        return os << p.x;
    }
    int operator !() {
        return !x;
    }
    friend bool operator == (ModInt lhs, ModInt rhs) {
        return lhs.x == rhs.x;
    }
    friend bool operator != (ModInt lhs, ModInt rhs) {
        return lhs.x != rhs.x;
    }
    ModInt operator -() {
        return ModInt(getMod() - x);
    }
    ModInt &operator ++() & {
        // Wrap around at the modulus; a bare ++x could reach getMod().
        x = x + 1 == getMod() ? 0 : x + 1;
        return *this;
    }
    ModInt operator ++(int) {
        ModInt temp = *this;
        ++*this;
        return temp;
    }
} ;
// Default runtime modulus for ModInt<0>; it can be changed with ModInt<0>::setMod.
template<>
i64 ModInt<0>::Mod = 998244353;
// NTT-friendly prime 167772161 = 5 * 2^25 + 1; 3 is a primitive root of it.
constexpr int P = 167772161;
constexpr int g = 3;
using Z = ModInt<P>;
struct Comb {
int n;
vector<Z> _fac;
vector<Z> _invfac;
vector<Z> _inv;
Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}
Comb(int n) : Comb() {init(n);}
void init(int m) {
m = min<int>(m, Z::getMod() - 1);
if (m <= n) return;
_fac.resize(m + 1);
_invfac.resize(m + 1);
_inv.resize(m + 1);
for (int i = n + 1; i <= m; i++) _fac[i] = _fac[i - 1] * i;
_invfac[m] = _fac[m].inv();
for (int i = m; i > n; i--) {
_invfac[i - 1] = _invfac[i] * i;
_inv[i] = _invfac[i] * _fac[i - 1];
} n = m;
}
Z fac(int m) {if (m > n) init(2 * m); return _fac[m];}
Z invfac(int m) {if (m > n) init(2 * m); return _invfac[m];}
Z inv(int m) {if (m > n) init(2 * m); return _inv[m];}
Z binom(int n, int m) {return n < m || m < 0 ? 0 : fac(n) * invfac(m) * invfac(n - m);}
} comb;
/// Bit-reversal permutation shared by every transform.
std::vector<int> rev;
/// Rebuilds `rev` for a transform of length n (n must be a power of two):
/// rev[i] is i with its log2(n) low bits reversed.
void ExtendRev(int n) {
    // The table depends only on n, so an equally sized table is already correct.
    if (static_cast<int>(rev.size()) == n) return;
    rev.assign(n, 0);
    // For n <= 1 the permutation is trivial; the formula below would shift by -1
    // (n == 1) or call __builtin_ctz(0) (n == 0), both undefined.
    if (n <= 1) return;
    const int s = __builtin_ctz(n) - 1;
    for (int i = 1; i < n; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << s);
}
/// Polynomial over ModInt<P> (index i holds the coefficient of x^i), multiplied
/// with an NTT that uses g as the primitive root of P. Transform lengths must be
/// powers of two dividing P - 1.
template<const int P, const int g>
struct Poly : public std::vector<ModInt<P>> {
    using V = ModInt<P>;
    // Inverse of g, the root used by the inverse transform.
    // NOTE(review): stored per object and recomputed on every construction; could be a static constant.
    int invg = V(g).inv().x;
    Poly() : std::vector<V>() {}
    Poly(int n) : std::vector<V>(n, V {}) {}
    Poly(int n, V x) : std::vector<V>(n, x) {}
    Poly(std::vector<V> _a) : std::vector<V>(_a) {}
    Poly(std::initializer_list<V> _a) : std::vector<V>(_a) {}
    template<class InputIt, class = std::_RequireInputIter<InputIt>>
    Poly(InputIt first, InputIt last) : std::vector<V>(first, last) {}
    // First n coefficients (zero-padded when n > size()).
    std::vector<V> trunc(int n) {
        auto f = *this;
        f.resize(n);
        return f;
    }
    /// In-place iterative NTT with `Root` as the primitive root
    /// (g for the forward transform, invg for the inverse one).
    void dft(int Root = g) {
        int n = this->size();
        // Length 0/1 transforms are the identity; this also keeps ExtendRev off n <= 1.
        if (n <= 1) return;
        ExtendRev(n);
        for (int i = 0; i < n; ++i)
            if (i < rev[i]) {
                std::swap((*this)[i], (*this)[rev[i]]);
            }
        for (int k = 1; k < n; k <<= 1) {
            // Primitive (2k)-th root of unity.
            V wn = power(V(Root), (P - 1) / (2 * k));
            for (int i = 0; i < n; i += 2 * k) {
                V w = 1;
                for (int j = 0; j < k; ++j, w *= wn) {
                    auto x = (*this)[i + j], y = (*this)[i + j + k] * w;
                    (*this)[i + j] = x + y;
                    (*this)[i + j + k] = x - y;
                }
            }
        }
    }
    /// Inverse NTT: forward transform with g^{-1}, then scale by n^{-1}.
    void idft() {
        dft(invg);
        V invn = V(int(this->size())).inv().x;
        for (auto &x : *this)
            x *= invn;
    }
    friend Poly operator * (Poly a, V b) {
        for (auto &x : a)
            x *= b;
        return a;
    }
    friend Poly operator * (Poly a, Poly b) {
        // An empty coefficient vector is the zero polynomial; so is the product.
        if (a.empty() || b.empty()) return Poly();
        int m = a.size() + b.size() - 1;  // coefficient count of the product
        // Smallest power of two >= m: a cyclic convolution of that length already
        // equals the linear convolution, so no further padding is needed.
        int n = 1;
        while (n < m) n <<= 1;
        a.resize(n);
        b.resize(n);
        a.dft();
        b.dft();
        for (int i = 0; i < n; ++i)
            a[i] *= b[i];
        a.idft();
        return a.trunc(m);
    }
    friend Poly operator + (Poly a, Poly b) {
        if (a.size() < b.size())
            std::swap(a, b);
        for (int i = 0; i < (int)b.size(); ++i)
            a[i] += b[i];
        return a;
    }
    Poly operator -() {
        Poly a = *this;
        for (int i = 0; i < (int)a.size(); ++i)
            a[i] = -a[i];
        return a;
    }
    friend Poly operator - (Poly a, Poly b) {
        return a + -b;
    }
} ;
using Pol = Poly<P, g>;
UPD on 2024/10/29
之前的板子写得很憨,重新封装了一个,目前只有多项式乘法、求逆、除法、取模和常系数齐次线性递推。
/// Binary exponentiation usable in constant expressions: returns a^b (b >= 0)
/// using only T's operator* and assignment.
template<typename T>
constexpr T power(T a, i64 b) {
    T acc = 1;
    while (b) {
        if (b & 1) {
            acc = acc * a;
        }
        b /= 2;
        a = a * a;
    }
    return acc;
}
/// Integer modulo P, or modulo the runtime-settable `Mod` when P <= 0.
/// Invariant: 0 <= x < getMod(). Sums are formed in int, so getMod() must keep
/// 2 * (getMod() - 1) within int range; products are formed in 64 bits.
template<int P>
class ModInt {
public:
    int x;
    constexpr ModInt() : x{0} {}
    // Both constructors reduce arbitrary (including negative) values into [0, getMod()).
    constexpr ModInt(int n) : x{(n % getMod() + getMod()) % getMod()} {}
    // The reduced value fits in int, but the expression is i64: cast explicitly
    // instead of narrowing inside braces (an error on some compilers).
    constexpr ModInt(i64 n) : x{static_cast<int>((n % getMod() + getMod()) % getMod())} {}
    // Runtime modulus; only read by getMod() when P <= 0.
    static int Mod;
    constexpr static void setMod(int _Mod) {
        Mod = _Mod;
    }
    constexpr static int getMod() {
        return P > 0 ? P : Mod;
    }
    constexpr ModInt &operator+=(ModInt k) & {
        x = (x + k.x >= getMod() ? x + k.x - getMod() : x + k.x);
        return *this;
    }
    constexpr ModInt &operator-=(ModInt k) & {
        x = (x - k.x < 0 ? x - k.x + getMod() : x - k.x);
        return *this;
    }
    constexpr ModInt &operator*=(ModInt k) & {
        x = 1LL * x * k.x % getMod();
        return *this;
    }
    friend constexpr ModInt operator+(ModInt lhs, ModInt rhs) {
        return lhs += rhs;
    }
    friend constexpr ModInt operator-(ModInt lhs, ModInt rhs) {
        return lhs -= rhs;
    }
    friend constexpr ModInt operator*(ModInt lhs, ModInt rhs) {
        return lhs *= rhs;
    }
    friend constexpr ModInt operator/(ModInt lhs, ModInt rhs) {
        return lhs /= rhs;
    }
    // Inverse via Fermat's little theorem; only valid for a prime modulus and x != 0.
    constexpr ModInt inv() const {
        return power(*this, getMod() - 2);
    }
    constexpr ModInt &operator/=(ModInt k) & {
        return (*this) *= k.inv();
    }
    friend constexpr std::istream &operator>>(std::istream &is, ModInt &k) {
        i64 val;
        is >> val;
        k = val;
        return is;
    }
    friend constexpr std::ostream &operator<<(std::ostream &os, ModInt k) {
        return os << k.x;
    }
    friend constexpr bool operator==(ModInt lhs, ModInt rhs) {
        return lhs.x == rhs.x;
    }
    friend constexpr bool operator!=(ModInt lhs, ModInt rhs) {
        return lhs.x != rhs.x;
    }
    constexpr bool operator!() {
        return !x;
    }
    constexpr ModInt &operator++() & {
        return (*this) += 1;
    }
    // Postfix operators return the old value BY VALUE; returning a reference to
    // the local `temp` would dangle.
    constexpr ModInt operator++(int) & {
        ModInt temp = *this;
        (*this) += 1;
        return temp;
    }
    constexpr ModInt &operator--() & {
        return (*this) -= 1;
    }
    constexpr ModInt operator--(int) & {
        ModInt temp = *this;
        (*this) -= 1;
        return temp;
    }
    friend constexpr ModInt operator-(const ModInt& a) {
        return a.getMod() - a.x;
    }
} ;
// Default runtime modulus for ModInt<0> (10^9 + 7); change it with ModInt<0>::setMod.
template<>
int ModInt<0>::Mod = 1000000007;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1.
constexpr int P = 998244353;
using Z = ModInt<P>;
/// Despite its name, returns a primitive 2^k-th root of unity modulo the odd prime P,
/// where 2^k is the largest power of two dividing P - 1. It takes the smallest
/// i >= 2 with i^((P-1)/2) != 1 (a quadratic non-residue by Euler's criterion)
/// and returns i^((P-1) / 2^k).
template<int P>
constexpr ModInt<P> findPrimitiveRoot() {
    const int twoAdicity = __builtin_ctz(P - 1);
    ModInt<P> candidate = 2;
    while (power(candidate, (P - 1) / 2) == 1) {
        candidate += 1;
    }
    return power(candidate, (P - 1) >> twoAdicity);
}
/// Root of unity of order 2^ctz(P - 1), computed at compile time.
template<int P>
constexpr ModInt<P> primitiveRoot{findPrimitiveRoot<P>()};
/// Twiddle table grown lazily by dft: roots<P>[k + j] holds the j-th power of a
/// primitive (2k)-th root of unity for each power of two k. Index 0 is unused.
template<int P>
std::vector<ModInt<P>> roots{ModInt<P>(0), ModInt<P>(1)};
/// Bit-reversal permutation shared by every transform.
std::vector<int> rev;
/// Rebuilds `rev` for a transform of length n (n must be a power of two):
/// rev[i] is i with its log2(n) low bits reversed.
/// Not constexpr: it mutates a global std::vector, so it can never be evaluated
/// at compile time (and C++17 compilers reject or warn on it as constexpr).
void extendRev(int n) {
    if (static_cast<int>(rev.size()) == n) {
        return;
    }
    rev.assign(n, 0);
    // For n <= 1 the permutation is trivial; the formula below would shift by -1
    // (n == 1) or call __builtin_ctz(0) (n == 0), both undefined.
    if (n <= 1) {
        return;
    }
    const int s = __builtin_ctz(n) - 1;
    for (int i = 1; i < n; ++i) {
        rev[i] = rev[i >> 1] >> 1 | (i & 1) << s;
    }
}
/// In-place forward NTT of `a`. The length n must be a power of two with
/// n <= 2^ctz(P - 1). Uses the shared bit-reversal table `rev` and lazily
/// extends roots<P>, where roots<P>[k + j] is the j-th power of a primitive
/// (2k)-th root of unity.
template<int P>
constexpr void dft(std::vector<ModInt<P>> &a) {
    const int n = a.size();
    // Transforms of length 0 or 1 are the identity; returning early also keeps
    // extendRev away from the degenerate sizes it cannot handle.
    if (n <= 1) {
        return;
    }
    extendRev(n);
    for (int i = 0; i < n; ++i) {
        if (rev[i] < i) {
            std::swap(a[i], a[rev[i]]);
        }
    }
    if (static_cast<int>(roots<P>.size()) < n) {
        int k = __builtin_ctz(roots<P>.size());
        roots<P>.resize(n);
        while ((1 << k) < n) {
            // e is a primitive 2^(k+1)-th root of unity.
            auto e = power(primitiveRoot<P>, 1 << (__builtin_ctz(P - 1) - k - 1));
            for (int i = 1 << (k - 1); i < (1 << k); i++) {
                roots<P>[2 * i] = roots<P>[i];
                roots<P>[2 * i + 1] = roots<P>[i] * e;
            }
            k++;
        }
    }
    // Cooley-Tukey butterflies; layer k combines blocks of size k into 2k.
    for (int k = 1; k < n; k *= 2) {
        for (int i = 0; i < n; i += 2 * k) {
            for (int j = 0; j < k; j++) {
                ModInt<P> u = a[i + j];
                ModInt<P> v = a[i + j + k] * roots<P>[k + j];
                a[i + j] = u + v;
                a[i + j + k] = u - v;
            }
        }
    }
}
/// In-place inverse NTT of `a`. Reversing a[1..n-1] turns the forward transform
/// into one with inverted roots; the result is then scaled by n^{-1}.
template<int P>
constexpr void idft(std::vector<ModInt<P>> &a) {
    const int n = a.size();
    // Empty input: nothing to transform. Without this guard, begin() + 1 would
    // step past end() and (1 - P) / n would divide by zero.
    if (n == 0) {
        return;
    }
    std::reverse(a.begin() + 1, a.end());
    dft(a);
    // n is a power of two dividing P - 1, so (1 - P) / n is exact and
    // n * ((1 - P) / n) = 1 - P, which is 1 mod P: this value is n^{-1}.
    ModInt<P> inv = (1 - P) / n;
    for (int i = 0; i < n; i++) {
        a[i] *= inv;
    }
}
/// Polynomial over ModInt<P> stored as its coefficient vector (index i holds the
/// coefficient of x^i). Multiplication uses dft/idft, so P must be NTT-friendly.
template<int P>
struct Poly : public std::vector<ModInt<P>> {
    using Z = ModInt<P>;
    constexpr Poly() : std::vector<Z>() {}
    constexpr Poly(int n) : std::vector<Z>(n) {}
    constexpr Poly(const std::vector<Z> &a) : std::vector<Z>(a) {}
    constexpr Poly(const std::initializer_list<Z> &a) : std::vector<Z>(a) {}
    template<class InputIt, class = std::_RequireInputIter<InputIt>>
    constexpr Poly(InputIt first, InputIt last) : std::vector<Z>(first, last) {}
    // Builds n coefficients from f[0..n-1]; f must support operator[].
    // NOTE(review): a callable generator f(i) is not accepted here.
    template<class F>
    constexpr Poly(int n, F f) : std::vector<Z>(n) {
        for (int i = 0; i < n; ++i) {
            (*this)[i] = f[i];
        }
    }
    // Cyclic rotation: coefficient i moves to index (i + k) mod size().
    // k may be negative (rotates the other way).
    constexpr Poly shift(int k) const {
        const int n = this->size();
        Poly res(n);
        if (n == 0) {
            return res;
        }
        // Normalize k into [0, n) so the index is never negative.
        const int r = (k % n + n) % n;
        for (int i = 0; i < n; ++i) {
            res[(i + r) % n] = (*this)[i];
        }
        return res;
    }
    // First `size` coefficients (zero-padded when size > this->size()).
    constexpr Poly trunc(int size) const {
        auto f = *this;
        f.resize(size);
        return f;
    }
    constexpr friend Poly operator*(Poly a, Z val) {
        for (auto& x : a) x *= val;
        return a;
    }
    constexpr friend Poly operator*(Z val, Poly a) {
        return a * val;
    }
    constexpr friend Poly operator/(Poly a, Z val) {
        for (auto& x : a) x /= val;
        return a;
    }
    constexpr friend Poly operator*(Poly a, Poly b) {
        // An empty coefficient vector is the zero polynomial; so is the product.
        if (a.empty() || b.empty()) {
            return Poly();
        }
        int m = a.size() + b.size() - 1, n = m;
        // Round n up to a power of two by repeatedly adding its lowest set bit.
        for (; n != (n & -n); n += (n & -n)) ;
        a.resize(n);
        b.resize(n);
        dft(a);
        dft(b);
        for (int i = 0; i < n; ++i) {
            a[i] *= b[i];
        }
        idft(a);
        return a.trunc(m);
    }
    constexpr friend Poly operator-(const Poly &a) {
        auto res = a;
        for (int i = 0; i < (int)a.size(); ++i) {
            res[i] = -res[i];
        }
        return res;
    }
    constexpr friend Poly operator-(Poly a, Poly b) {
        Poly res(std::max(a.size(), b.size()));
        for (int i = 0; i < (int)a.size(); ++i) {
            res[i] = a[i];
        }
        for (int i = 0; i < (int)b.size(); ++i) {
            res[i] -= b[i];
        }
        return res;
    }
    constexpr friend Poly operator+(Poly a, Poly b) {
        Poly res(std::max(a.size(), b.size()));
        for (int i = 0; i < (int)a.size(); ++i) {
            res[i] += a[i];
        }
        // Iterate over b's own length: reading b[i] up to a.size() overran b
        // or ignored b's higher coefficients.
        for (int i = 0; i < (int)b.size(); ++i) {
            res[i] += b[i];
        }
        return res;
    }
    constexpr Poly &operator-=(Poly a) {
        return (*this) = (*this) - a;
    }
    constexpr Poly &operator+=(Poly a) {
        return (*this) = (*this) + a;
    }
    constexpr Poly &operator*=(Poly a) {
        return (*this) = (*this) * a;
    }
    constexpr Poly &operator*=(Z a) {
        return (*this) = (*this) * a;
    }
    constexpr Poly &operator/=(Z a) {
        return (*this) = (*this) / a;
    }
    // First m coefficients of 1 / (*this) via Newton iteration x <- x(2 - f x),
    // doubling the precision each step. Requires (*this)[0] to be invertible.
    constexpr Poly inv(int m) const {
        Poly x{(*this)[0].inv()};
        int k = 1;
        while (k < m) {
            k *= 2;
            x = (x * (Poly{2} - trunc(k) * x)).trunc(k);
        }
        return x.trunc(m);
    }
    constexpr Poly inv() const {
        return inv(this->size());
    }
    // Polynomial quotient via reversal: rev(q) = rev(a) * rev(b)^{-1} mod x^(n-m+1).
    // The last coefficient of b must be non-zero.
    constexpr friend Poly operator/(Poly a, Poly b) {
        int n = a.size(), m = b.size();
        if (n < m) {
            n = m;
            a.resize(n);
        }
        int l = n - m + 1;
        std::reverse(a.begin(), a.end());
        std::reverse(b.begin(), b.end());
        Poly res = (a.trunc(l) * (b.trunc(l)).inv()).trunc(l);
        std::reverse(res.begin(), res.end());
        return res;
    }
    // Remainder a - b * (a / b), truncated to b.size() - 1 coefficients.
    constexpr friend Poly operator%(Poly a, Poly b) {
        return (a - b * (a / b)).trunc(b.size() - 1);
    }
    constexpr Poly &operator/=(Poly b) {
        return (*this) = (*this) / b;
    }
    constexpr Poly &operator%=(Poly b) {
        return (*this) = (*this) % b;
    }
} ;
/// n-th term (0-indexed) of the linear recurrence
///   a_t = f[1] * a_{t-1} + f[2] * a_{t-2} + ... + f[k] * a_{t-k},
/// given the initial terms a[0..k-1] (k = a.size()) and coefficients f, with f[0] ignored.
/// Assumes f.size() == a.size() + 1 -- TODO confirm with callers.
/// Computes x^n modulo the characteristic polynomial by binary exponentiation,
/// then combines the resulting coefficients with the initial terms.
template<int P>
constexpr ModInt<P> linearRecurrence(Poly<P> f, Poly<P> a, i64 n) {
    const int k = a.size();
    // No initial terms: the only order-0 recurrence is the zero sequence.
    if (k == 0) {
        return ModInt<P>(0);
    }
    // Turn f into the characteristic polynomial
    // x^k - f[1] x^{k-1} - ... - f[k], stored low-to-high.
    f[0] = -1;
    f = -f;
    std::reverse(f.begin(), f.end());
    // res = 1 and base = x; base needs two slots even when k == 1.
    Poly<P> res(k), base(std::max(k, 2));
    res[0] = 1;
    base[1] = 1;
    for (; n; n /= 2, base = base * base % f) {
        if (n & 1) {
            res = res * base % f;
        }
    }
    // Use this instantiation's ModInt<P>; the global alias Z is fixed to 998244353.
    ModInt<P> ans = 0;
    for (int i = 0; i < k; ++i) {
        ans += res[i] * a[i];
    }
    return ans;
}
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 按钮权限的设计及实现
· 25岁的心里话