NTT 学习笔记

NTT

前置知识:FFT

NTT,中文“快速数论变换”,是 FFT 在数论领域上的实现,比 FFT 更快,应用更广。

对于 FFT,因为其涉及到复数操作,对于某些需要取模的题不再适用。并且因为需要求正弦与余弦,使用时难以避免精度误差。这时就需要用到 NTT 来解决问题了。

我们知道 FFT 的实现是在复平面上找到了 $n$ 个不同的点 $\omega_n^0, \omega_n^1, \dots, \omega_n^{n-1}$。由于现在需要解决取模的问题,因此我们考虑在剩余系中找 $n$ 个等价的点。

原根

我们可以使用原根找到这 n 个数。由于原根通常较小,所以可以暴力求原根。

假设要取模的奇素数为 $p$,再设 $p = qn + 1$,即 $n \mid (p - 1)$。

设 $g$ 为模 $p$ 意义下的原根,根据费马小定理和欧拉定理得:$g^{nq} \equiv g^{\varphi(p)} \equiv 1 \pmod p$。其中 $\varphi$ 为欧拉函数,由于 $p$ 是素数,故 $\varphi(p) = p - 1$。

为了验证原根与单位根的相似性质,不妨令 $x_n^i = g^{iq} = \left(g^{\frac{p-1}{n}}\right)^i$,看做和 $\omega_n^i$ 等价。我们接下来验证 $x_n^i$ 是否与 $\omega_n^i$ 等价。

  • 乘法:$\omega_n^i \times \omega_n^j = \omega_n^{i+j}$。对于 $x_n^i$ 显然成立。

  • 周期性:$\omega_n^i = \omega_n^{i+n}$。由于 $x_n^n = g^{nq} \equiv 1 \pmod p$,所以 $x_n^{i+n} = x_n^i \times x_n^n = x_n^i$,成立。

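  • 互异性:$\omega_n^0, \omega_n^1, \dots, \omega_n^{n-1}$ 互不相同。根据原根的性质(原根 $g$ 的阶为 $p - 1$,而指数 $iq$ 在 $0 \le i < n$ 时两两不同且都小于 $p - 1$)可以很容易得到 $x_n^0, x_n^1, \dots, x_n^{n-1}$ 互不相同。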

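  • 消去引理:$\forall n \in \mathbb{N}, k \in \mathbb{N}, d \in \mathbb{N}$,有 $\omega_{dn}^{dk} = \omega_n^k$。证明:$x_{dn}^{dk} = g^{dk \cdot \frac{p-1}{dn}} = \left(g^{\frac{p-1}{n}}\right)^k = x_n^k$。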

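  • 折半引理:$\forall n \in \mathbb{N}, k \in \mathbb{N}$ 且 $n$ 为偶数,有 $\left(\omega_n^{k+n/2}\right)^2 = \omega_{n/2}^k$。首先 $\left(x_n^{k+n/2}\right)^2 = x_n^{2k+n}$,因为有周期性 $x_n^{2k+n} = x_n^{2k}$,再由消去引理得到 $x_n^{2k} = x_{n/2}^k$。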

可以发现,因为 $x_n^i$ 同样满足消去引理与折半引理,所以后面对多项式的推导中 $x_n^i$ 与 $\omega_n^i$ 完全等价。

模数

不是所有模数都可以使用 NTT(至少不能用朴素的 NTT)。

FFT 中,为了方便使用折半引理,我们强制将 $n$ 补齐为 $2^{\lceil \log_2 n \rceil}$,这样的话我们的模数必须满足 $p = q \cdot 2^k + 1$,其中 $q$ 是奇数,并且必须有 $n \le 2^k$。后面可能会收录一些 NTT 常用模数,平常可以用 $998244353$。

INTT

将 $x_n^i$ 换成 $\frac{1}{x_n^i}$ 后用逆元即可,除以 $n$ 时也用逆元。

代码:

// Binary exponentiation: returns `a` multiplied by itself `b` times, using
// only T's `*=` operator. T must be constructible from the integer 1, which
// is the value returned when b == 0.
// The exponent is consumed with a right shift, so b must be non-negative.
template<class T>
T power(T a, i64 b) {
    T result = 1;
    while (b) {
        // Fold in the current power of a whenever the matching exponent bit is set.
        if (b & 1) {
            result *= a;
        }
        a *= a;
        b >>= 1;
    }
    return result;
}
// Integer modulo getMod(), always stored as the canonical representative in
// [0, getMod()).
//
// P != 0 selects a compile-time modulus. P == 0 uses the runtime modulus
// `Mod`, changeable through setMod().
// inv() and division raise to the power getMod() - 2 (Fermat's little
// theorem), so they are only meaningful for a prime modulus and a non-zero
// operand.
template<const i64 P>
class ModInt {
public:
    i64 x;           // canonical value in [0, getMod())
    static i64 Mod;  // runtime modulus; only read when P == 0

    ModInt() : x{0} {}
    // Both constructors accept any (possibly negative) integer and reduce it.
    ModInt(int _x) : x{(_x % getMod() + getMod()) % getMod()} {}
    ModInt(i64 _x) : x{(_x % getMod() + getMod()) % getMod()} {}

    static void setMod(i64 Mod_) {
        Mod = Mod_;
    }
    static i64 getMod() {
        return !P ? Mod : P;
    }
    explicit constexpr operator int() const {
        return x;
    }

    ModInt &operator += (ModInt a) & {
        x = x + a.x >= getMod() ? x + a.x - getMod() : x + a.x;
        return (*this);
    }
    ModInt &operator -= (ModInt a) & {
        x = x - a.x < 0 ? x - a.x + getMod() : x - a.x;
        return (*this);
    }
    ModInt &operator *= (ModInt a) & {
        (x *= a.x) %= getMod();
        return (*this);
    }
    // Multiplicative inverse via x^(mod - 2).
    constexpr ModInt inv() const {
        return power((*this), getMod() - 2);
    }
    ModInt &operator /= (ModInt a) & {
        return (*this) *= a.inv();
    }
    friend ModInt operator + (ModInt lhs, ModInt rhs) {
        return lhs += rhs;
    }
    friend ModInt operator - (ModInt lhs, ModInt rhs) {
        return lhs -= rhs;
    }
    friend ModInt operator * (ModInt lhs, ModInt rhs) {
        return lhs *= rhs;
    }
    friend ModInt operator / (ModInt lhs, ModInt rhs) {
        return lhs /= rhs;
    }

    // Reads an integer and reduces it, so negative or oversized input cannot
    // break the [0, getMod()) invariant.
    friend std::istream &operator >> (std::istream &is, ModInt &p) {
        i64 raw = 0;
        is >> raw;
        p = ModInt(raw);
        return is;
    }
    friend std::ostream &operator << (std::ostream &os, ModInt p) {
        return os << p.x;
    }
    int operator !() const {
        return !x;
    }
    friend bool operator == (ModInt lhs, ModInt rhs) {
        return lhs.x == rhs.x;
    }
    friend bool operator != (ModInt lhs, ModInt rhs) {
        return lhs.x != rhs.x;
    }
    // The constructor folds getMod() - 0 back to 0.
    ModInt operator -() const {
        return ModInt(getMod() - x);
    }
    // Increment with wrap-around: mod - 1 is followed by 0.
    ModInt &operator ++() & {
        if (++x == getMod()) {
            x = 0;
        }
        return *this;
    }
    ModInt operator ++(int) {
        ModInt temp = *this;
        ++*this;
        return temp;
    }
} ;
// Initial runtime modulus of the dynamic-modulus instantiation ModInt<0>.
template<>
i64 ModInt<0>::Mod = 998244353;

// NTT modulus 167772161 = 5 * 2^25 + 1, and the generator handed to Poly.
const int P = 167772161;
const int g = 3;

// Scalar type used by the combinatorics helper.
using Z = ModInt<P>;
struct Comb {
    int n;
    vector<Z> _fac;
    vector<Z> _invfac;
    vector<Z> _inv;
    Comb() : n{0}, _fac{1}, _invfac{1}, _inv{0} {}
    Comb(int n) : Comb() {init(n);}
    void init(int m) {
        m = min<int>(m, Z::getMod() - 1);
        if (m <= n) return;
        _fac.resize(m + 1);
        _invfac.resize(m + 1);
        _inv.resize(m + 1);
        for (int i = n + 1; i <= m; i++) _fac[i] = _fac[i - 1] * i;
        _invfac[m] = _fac[m].inv();
        for (int i = m; i > n; i--) {
            _invfac[i - 1] = _invfac[i] * i;
            _inv[i] = _invfac[i] * _fac[i - 1];
        } n = m;
    }
    Z fac(int m) {if (m > n) init(2 * m); return _fac[m];}
    Z invfac(int m) {if (m > n) init(2 * m); return _invfac[m];}
    Z inv(int m) {if (m > n) init(2 * m); return _inv[m];}
    Z binom(int n, int m) {return n < m || m < 0 ? 0 : fac(n) * invfac(m) * invfac(n - m);}
} comb;
// Bit-reversal permutation shared by the transforms: after ExtendRev(n),
// rev[i] is i with its lowest log2(n) bits reversed.
std::vector<int> rev;

// Rebuilds `rev` for a transform of length n, where n is a power of two.
// Lengths 0 and 1 are handled separately: __builtin_ctz(0) is undefined and
// n == 1 would shift by -1.
void ExtendRev(int n) {
    if (n <= 1) {
        rev.assign(n == 1 ? 1 : 0, 0);
        return;
    }
    rev.assign(n, 0);
    const int s = __builtin_ctz(n) - 1;  // position of the top bit of an index
    // rev[0] stays 0; each later entry derives from the entry for i / 2.
    for (int i = 1; i < n; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << s);
}

// Polynomial over ModInt<P>, stored as a coefficient vector (index = degree),
// with NTT-based multiplication.
// `g` is the generator whose powers are used as roots of unity; a transform
// of length n needs n to be a power of two dividing P - 1.
template<const int P, const int g>
struct Poly : public std::vector<ModInt<P>> {
    using V = ModInt<P>;
    int invg = V(g).inv().x;  // inverse of g: the root used by the inverse transform

    Poly() : std::vector<V>() {}
    Poly(int n) : std::vector<V>(n, V {}) {}
    Poly(int n, V x) : std::vector<V>(n, x) {}
    Poly(std::vector<V> _a) : std::vector<V>(_a) {}
    Poly(std::initializer_list<V> _a) : std::vector<V>(_a) {}
    template<class InputIt, class = std::_RequireInputIter<InputIt>>
    Poly(InputIt first, InputIt last) : std::vector<V>(first, last) {}

    // Copy of the coefficients resized to exactly n entries (zero-padded or cut).
    std::vector<V> trunc(int n) const {
        std::vector<V> f(*this);
        f.resize(n);
        return f;
    }

    // In-place transform of length size(): bit-reversal permutation followed
    // by iterative butterflies. Root is g for the forward transform and invg
    // for the inverse one.
    void dft(int Root = g) {
        const int n = this->size();
        if (n <= 1) {
            return;  // a transform of length 0 or 1 is the identity
        }
        ExtendRev(n);
        for (int i = 0; i < n; ++i)
            if (i < rev[i]) {
                std::swap((*this)[i], (*this)[rev[i]]);
            }

        for (int k = 1; k < n; k <<= 1) {
            // wn is a primitive (2k)-th root of unity.
            V wn = power(V(Root), (P - 1) / (2 * k));
            for (int i = 0; i < n; i += 2 * k) {
                V w = 1;
                for (int j = 0; j < k; ++j, w *= wn) {
                    auto x = (*this)[i + j], y = (*this)[i + j + k] * w;
                    (*this)[i + j] = x + y;
                    (*this)[i + j + k] = x - y;
                }
            }
        }
    }
    // Inverse transform: run dft with the inverse root, then divide by the length.
    void idft() {
        dft(invg);
        const V invn = V(int(this->size())).inv();
        for (auto &x : *this)
            x *= invn;
    }

    friend Poly operator * (Poly a, V b) {
        for (auto &x : a)
            x *= b;
        return a;
    }
    // Product of two polynomials; the result has a.size() + b.size() - 1
    // coefficients.
    friend Poly operator * (Poly a, Poly b) {
        const int m = static_cast<int>(a.size() + b.size()) - 1;
        if (a.empty() || b.empty()) {
            // The product is zero; skipping the transform also avoids a
            // zero-length NTT and a negative resize.
            return Poly(m > 0 ? m : 0);
        }
        // Smallest power of two that holds all m coefficients.
        int n = 1;
        while (n < m) {
            n <<= 1;
        }
        a.resize(n);
        b.resize(n);

        a.dft();
        b.dft();
        for (int i = 0; i < n; ++i)
            a[i] *= b[i];
        a.idft();

        return a.trunc(m);
    }
    friend Poly operator + (Poly a, Poly b) {
        // Accumulate the shorter operand into the longer one.
        if (a.size() < b.size())
            std::swap(a, b);
        const int shorter = b.size();
        for (int i = 0; i < shorter; ++i)
            a[i] += b[i];
        return a;
    }
    Poly operator -() const {
        Poly a = *this;
        for (auto &c : a)
            c = -c;
        return a;
    }
    friend Poly operator - (Poly a, Poly b) {
        return a + -b;
    }
} ;
// Default polynomial type: coefficients modulo the global P, transformed with generator g.
using Pol = Poly<P, g>;

UPD on 2024/10/29

之前的板子写得很憨,重新封装了一个,目前只有多项式除法和常系数齐次线性递推。

// Binary exponentiation usable in constant expressions: returns `a` raised to
// the power `b`, using only T's binary `*`. T must be constructible from the
// integer 1, which is the result for b == 0.
template<typename T>
constexpr T power(T a, i64 b) {
	T acc = 1;
	while (b) {
		// Multiply in the current power of a when the lowest exponent bit is set.
		if (b & 1) {
			acc = acc * a;
		}
		a = a * a;
		b /= 2;
	}
	return acc;
}

// Integer modulo getMod(), always stored as the canonical representative in
// [0, getMod()).
//
// P > 0 selects a compile-time modulus; otherwise the runtime modulus `Mod`
// is used (see setMod). Sums of two residues are formed in `int`, so the
// modulus has to stay below 2^30.
// inv() and division raise to the power getMod() - 2 (Fermat's little
// theorem), so they are only meaningful for a prime modulus and a non-zero
// operand.
template<int P>
class ModInt {
public:
	int x;  // canonical value in [0, getMod())

	constexpr ModInt() : x{0} {}
	// Both constructors accept any (possibly negative) integer and reduce it.
	constexpr ModInt(int n) : x{(n % getMod() + getMod()) % getMod()} {}
	// The reduced value fits in int; the cast makes the i64 -> int conversion
	// explicit, because brace-initialisation forbids implicit narrowing.
	constexpr ModInt(i64 n) : x{static_cast<int>((n % getMod() + getMod()) % getMod())} {}

	static int Mod;  // runtime modulus; only read when P <= 0
	static void setMod(int mod) {
		Mod = mod;
	}
	constexpr static int getMod() {
		return P > 0 ? P : Mod;
	}

	constexpr ModInt &operator+=(ModInt k) & {
		x = (x + k.x >= getMod() ? x + k.x - getMod() : x + k.x);
		return *this;
	}
	constexpr ModInt &operator-=(ModInt k) & {
		x = (x - k.x < 0 ? x - k.x + getMod() : x - k.x);
		return *this;
	}
	constexpr ModInt &operator*=(ModInt k) & {
		// Widen to 64 bits before multiplying to avoid overflow.
		x = 1LL * x * k.x % getMod();
		return *this;
	}

	friend constexpr ModInt operator+(ModInt lhs, ModInt rhs) {
		return lhs += rhs;
	}
	friend constexpr ModInt operator-(ModInt lhs, ModInt rhs) {
		return lhs -= rhs;
	}
	friend constexpr ModInt operator*(ModInt lhs, ModInt rhs) {
		return lhs *= rhs;
	}
	friend constexpr ModInt operator/(ModInt lhs, ModInt rhs) {
		return lhs /= rhs;
	}

	// Multiplicative inverse via x^(mod - 2).
	constexpr ModInt inv() const {
		return power(*this, getMod() - 2);
	}
	constexpr ModInt &operator/=(ModInt k) & {
		return (*this) *= k.inv();
	}

	// Reads an integer and reduces it into [0, getMod()).
	friend std::istream &operator>>(std::istream &is, ModInt &k) {
		i64 val;
		is >> val;
		k = val;
		return is;
	}
	friend std::ostream &operator<<(std::ostream &os, ModInt k) {
		return os << k.x;
	}

	friend constexpr bool operator==(ModInt lhs, ModInt rhs) {
		return lhs.x == rhs.x;
	}
	friend constexpr bool operator!=(ModInt lhs, ModInt rhs) {
		return lhs.x != rhs.x;
	}
	constexpr bool operator!() const {
		return !x;
	}

	constexpr ModInt &operator++() & {
		return (*this) += 1;
	}
	// Postfix forms return the previous value BY VALUE: the copy is a local,
	// so returning a reference to it would dangle.
	constexpr ModInt operator++(int) & {
		ModInt previous = *this;
		(*this) += 1;
		return previous;
	}
	constexpr ModInt &operator--() & {
		return (*this) -= 1;
	}
	constexpr ModInt operator--(int) & {
		ModInt previous = *this;
		(*this) -= 1;
		return previous;
	}

	// Additive inverse; the converting constructor folds mod - 0 back to 0.
	friend constexpr ModInt operator-(const ModInt& a) {
		return a.getMod() - a.x;
	}
} ;

// Initial runtime modulus of the dynamic-modulus instantiation ModInt<0>.
template<>
int ModInt<0>::Mod = 1'000'000'007;

// Compile-time NTT modulus 998244353 = 119 * 2^23 + 1.
constexpr int P = 998244353;

// Default scalar type built on that modulus.
using Z = ModInt<P>;

// Returns candidate^((P - 1) / 2^k), where 2^k is the largest power of two
// dividing P - 1 and `candidate` is the smallest value >= 2 whose
// ((P - 1) / 2)-th power differs from 1.
// For a prime P such a candidate is a quadratic non-residue (Euler's
// criterion), so the returned element has multiplicative order exactly 2^k.
template<int P>
constexpr ModInt<P> findPrimitiveRoot() {
	const int twoAdicity = __builtin_ctz(P - 1);

	ModInt<P> candidate = 2;
	while (power(candidate, (P - 1) / 2) == 1) {
		candidate += 1;
	}

	// Strip the odd part of P - 1 from the exponent.
	return power(candidate, (P - 1) >> twoAdicity);
}

// Compile-time cache of findPrimitiveRoot<P>(); dft raises it to powers of
// two to obtain the roots of unity it needs.
template<int P>
constexpr ModInt<P> primitiveRoot = findPrimitiveRoot<P>();

// Per-modulus table of twiddle factors, seeded with {0, 1} and extended
// lazily inside dft. dft reads roots<P>[k + j] as the factor for the j-th
// butterfly of a block with half-length k.
template<int P>
std::vector<ModInt<P>> roots = {0, 1};

// Bit-reversal permutation shared by the transforms: after extendRev(n),
// rev[i] is i with its lowest log2(n) bits reversed.
std::vector<int> rev;

// Ensures `rev` describes a transform of length n, where n is a power of two.
// The table depends only on n, so it is reused when the size already matches.
// Not constexpr: the function mutates a namespace-scope vector and can never
// be evaluated at compile time.
inline void extendRev(int n) {
	if (static_cast<int>(rev.size()) == n) {
		return;
	}
	if (n <= 1) {
		// __builtin_ctz(0) is undefined and n == 1 would shift by -1.
		rev.assign(n == 1 ? 1 : 0, 0);
		return;
	}
	rev.assign(n, 0);
	const int s = __builtin_ctz(n) - 1;  // position of the top bit of an index
	// rev[0] stays 0; each later entry derives from the entry for i / 2.
	for (int i = 1; i < n; ++i) {
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << s);
	}
}

// In-place forward number-theoretic transform of `a`.
// a.size() must be a power of two no larger than the largest power of two
// dividing P - 1. Lengths 0 and 1 are left untouched, since that transform is
// the identity.
template<int P>
constexpr void dft(std::vector<ModInt<P>> &a) {
	const int n = a.size();
	if (n <= 1) {
		return;
	}

	extendRev(n);

	// Apply the bit-reversal permutation (each pair is swapped once).
	for (int i = 0; i < n; ++i) {
		if (rev[i] < i) {
			std::swap(a[i], a[rev[i]]);
		}
	}

	// Grow the twiddle table one level at a time. Level k fills indices
	// [2^k, 2^(k+1)) from level k - 1, using e = a primitive 2^(k+1)-th root
	// of unity: even slots copy the parent, odd slots multiply it by e.
	if (static_cast<int>(roots<P>.size()) < n) {
		int k = __builtin_ctz(roots<P>.size());
		roots<P>.resize(n);
		while ((1 << k) < n) {
			auto e = power(primitiveRoot<P>, 1 << (__builtin_ctz(P - 1) - k - 1));
			for (int i = 1 << (k - 1); i < (1 << k); i++) {
				roots<P>[2 * i] = roots<P>[i];
				roots<P>[2 * i + 1] = roots<P>[i] * e;
			}
			k++;
		}
	}

	// Iterative butterflies: k is the half-length of the blocks being merged.
	for (int k = 1; k < n; k *= 2) {
		for (int i = 0; i < n; i += 2 * k) {
			for (int j = 0; j < k; j++) {
				ModInt<P> u = a[i + j];
				ModInt<P> v = a[i + j + k] * roots<P>[k + j];
				a[i + j] = u + v;
				a[i + j + k] = u - v;
			}
		}
	}
}

// In-place inverse number-theoretic transform of `a`.
// Same length requirements as dft: a power of two dividing P - 1.
// An empty vector is returned unchanged; without the guard the code would
// form a.begin() + 1 on an empty range and divide by zero.
template<int P>
constexpr void idft(std::vector<ModInt<P>> &a) {
	const int n = a.size();
	if (n == 0) {
		return;
	}
	// Reversing a[1..n) and running the forward transform is equivalent to
	// transforming with the inverse root.
	std::reverse(a.begin() + 1, a.end());
	dft(a);
	// n divides P - 1, so (1 - P) / n is an exact integer and
	// n * ((1 - P) / n) = 1 - P, which is 1 modulo P: this is n's inverse.
	const ModInt<P> invN = (1 - P) / n;
	for (auto &coef : a) {
		coef *= invN;
	}
}

// Polynomial over ModInt<P>, stored as a coefficient vector (index = degree).
// Multiplication uses the NTT helpers dft/idft, so result lengths are limited
// to the largest power of two dividing P - 1.
template<int P>
struct Poly : public std::vector<ModInt<P>> {
	using Z = ModInt<P>;

	constexpr Poly() : std::vector<Z>() {}

	// n zero coefficients.
	constexpr Poly(int n) : std::vector<Z>(n) {}
	constexpr Poly(const std::vector<Z> &a) : std::vector<Z>(a) {}
	constexpr Poly(const std::initializer_list<Z> &a) : std::vector<Z>(a) {}

	template<class InputIt, class = std::_RequireInputIter<InputIt>>
	constexpr Poly(InputIt first, InputIt last) : std::vector<Z>(first, last) {}

	// Fills coefficient i from f[i]; f is indexed with operator[], so it must
	// be array-like.
	// NOTE(review): a callable f(i) may have been intended here — confirm.
	template<class F>
	constexpr Poly(int n, F f) : std::vector<Z>(n) {
		for (int i = 0; i < n; ++i) {
			(*this)[i] = f[i];
		}
	}

	// Cyclic rotation: coefficient i moves to index (i + k) mod size().
	// k may be negative or larger than size().
	constexpr Poly shift(int k) const {
		const int n = this->size();
		Poly res(n);
		if (n == 0) {
			return res;
		}
		// Normalise once so the index below can never be negative.
		const int offset = (k % n + n) % n;
		for (int i = 0; i < n; ++i) {
			res[(i + offset) % n] = (*this)[i];
		}
		return res;
	}

	// Copy resized to exactly `size` coefficients (zero-padded or cut).
	constexpr Poly trunc(int size) const {
		auto f = *this;
		f.resize(size);
		return f;
	}

	constexpr friend Poly operator*(Poly a, Z val) {
		for (auto& x : a) x *= val;
		return a;
	}
	constexpr friend Poly operator*(Z val, Poly a) {
		return a * val;
	}
	constexpr friend Poly operator/(Poly a, Z val) {
		for (auto& x : a) x /= val;
		return a;
	}
	// Product of two polynomials; the result has a.size() + b.size() - 1
	// coefficients (none when that count is not positive).
	constexpr friend Poly operator*(Poly a, Poly b) {
		const int m = static_cast<int>(a.size() + b.size()) - 1;
		if (a.empty() || b.empty()) {
			// The product is zero. Returning here avoids a zero-length
			// transform (division by zero in idft) and a negative resize.
			return Poly(m > 0 ? m : 0);
		}
		if (m == 1) {
			// Two constants: a length-1 transform is the identity, so the
			// product is just the product of the two coefficients.
			return Poly{a[0] * b[0]};
		}

		// Smallest power of two that holds all m coefficients.
		int n = 1;
		while (n < m) {
			n *= 2;
		}

		a.resize(n);
		b.resize(n);

		dft(a);
		dft(b);
		for (int i = 0; i < n; ++i) {
			a[i] *= b[i];
		}
		idft(a);

		return a.trunc(m);
	}
	constexpr friend Poly operator-(const Poly &a) {
		auto res = a;
		for (auto &c : res) {
			c = -c;
		}
		return res;
	}
	constexpr friend Poly operator-(Poly a, Poly b) {
		const int na = a.size(), nb = b.size();
		Poly res(std::max(na, nb));
		for (int i = 0; i < na; ++i) {
			res[i] = a[i];
		}
		for (int i = 0; i < nb; ++i) {
			res[i] -= b[i];
		}
		return res;
	}
	constexpr friend Poly operator+(Poly a, Poly b) {
		const int na = a.size(), nb = b.size();
		Poly res(std::max(na, nb));
		for (int i = 0; i < na; ++i) {
			res[i] += a[i];
		}
		// Iterate over b's own length: using a's length here would drop b's
		// higher coefficients or read past the end of b.
		for (int i = 0; i < nb; ++i) {
			res[i] += b[i];
		}
		return res;
	}
	constexpr Poly &operator-=(Poly a) {
		return (*this) = (*this) - a;
	}
	constexpr Poly &operator+=(Poly a) {
		return (*this) = (*this) + a;
	}
	constexpr Poly &operator*=(Poly a) {
		return (*this) = (*this) * a;
	}
	constexpr Poly &operator*=(Z a) {
		return (*this) = (*this) * a;
	}
	constexpr Poly &operator/=(Z a) {
		return (*this) = (*this) / a;
	}

	// Power-series inverse modulo x^m by Newton iteration
	// x <- x * (2 - f * x), doubling the precision each round.
	// Requires a non-empty polynomial with an invertible constant term.
	constexpr Poly inv(int m) const {
		Poly x{(*this)[0].inv()};
		int k = 1;
		while (k < m) {
			k *= 2;
			x = (x * (Poly{2} - trunc(k) * x)).trunc(k);
		}
		return x.trunc(m);
	}
	constexpr Poly inv() const {
		return inv(this->size());
	}
	// Quotient of polynomial division, computed on the reversed coefficient
	// lists: rev(q) = rev(a) * rev(b)^(-1) mod x^l with l = n - m + 1.
	// b's leading coefficient (b.back()) must be invertible.
	constexpr friend Poly operator/(Poly a, Poly b) {
		int n = a.size(), m = b.size();

		if (n < m) {
			// Pad a with zero high coefficients so the quotient comes out as 0.
			n = m;
			a.resize(n);
		}
		int l = n - m + 1;

		std::reverse(a.begin(), a.end());
		std::reverse(b.begin(), b.end());

		Poly res = (a.trunc(l) * (b.trunc(l)).inv()).trunc(l);
		std::reverse(res.begin(), res.end());
		return res;
	}
	// Remainder a - b * (a / b), cut to b.size() - 1 coefficients.
	constexpr friend Poly operator%(Poly a, Poly b) {
		return (a - b * (a / b)).trunc(b.size() - 1);
	}
	constexpr Poly &operator/=(Poly b) {
		return (*this) = (*this) / b;
	}
	constexpr Poly &operator%=(Poly b) {
		return (*this) = (*this) % b;
	}
} ;
// n-th term of a constant-coefficient homogeneous linear recurrence.
//
// `a` holds the first k = a.size() terms a_0 .. a_{k-1}. `f` holds the
// coefficients in f[1..k]; f[0] is overwritten, so its incoming value is
// ignored. After the transformation below f is the characteristic polynomial
// x^k - f[1] x^(k-1) - ... - f[k], i.e. a_i = f[1] a_{i-1} + ... + f[k] a_{i-k}.
// Presumably f.size() == a.size() + 1 — confirm against callers.
//
// The result is the dot product of (x^n mod f) with the initial terms.
// n must be non-negative.
template<int P>
constexpr ModInt<P> linearRecurrence(Poly<P> f, Poly<P> a, i64 n) {
	const int k = a.size();
	if (k == 0) {
		// No initial terms: nothing to combine.
		return ModInt<P>{};
	}

	// Build the characteristic polynomial, lowest degree first.
	f[0] = -1;
	f = -f;
	std::reverse(f.begin(), f.end());

	// res = x^0 and base = x, both reduced modulo f (k coefficients each).
	Poly<P> res(k), base(k);
	res[0] = 1;
	if (k == 1) {
		// f = x - c, so x mod f is the constant c = -f[0]; base has no index 1.
		base[0] = -f[0];
	} else {
		base[1] = 1;
	}

	// Square-and-multiply on polynomials modulo f.
	for (; n; n /= 2, base = base * base % f) {
		if (n & 1) {
			res = res * base % f;
		}
	}

	// Accumulate in ModInt<P> rather than the file-level alias Z, so the
	// function also compiles for moduli other than the global one.
	ModInt<P> ans = 0;
	for (int i = 0; i < k; ++i) {
		ans += res[i] * a[i];
	}
	return ans;
}

作者:CTHOOH

出处:https://www.cnblogs.com/CTHOOH/p/18187174

版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。

posted @   CTHOOH  阅读(17)  评论(0)  编辑  收藏  举报
相关博文:
阅读排行:
· 分享一个免费、快速、无限量使用的满血 DeepSeek R1 模型,支持深度思考和联网搜索!
· 基于 Docker 搭建 FRP 内网穿透开源项目(很简单哒)
· ollama系列01:轻松3步本地部署deepseek,普通电脑可用
· 按钮权限的设计及实现
· 25岁的心里话
more_horiz
keyboard_arrow_up dark_mode palette
选择主题
点击右上角即可分享
微信分享提示