【模板】modint

latest version: https://raw.githubusercontent.com/caijianhong/template-poly/main/poly/modint.hpp

update on 2024.10.7:上传了 modint.min.hpp,作为关掉 clang-format 后排版更紧凑、看起来更舒服的版本。

modint (clang-format on, version 3, support dynamic_modint)
template <class T>
using must_int = enable_if_t<is_integral<T>::value, int>;
// Integers modulo the compile-time modulus umod. v is always kept in
// [0, umod). Division and negative powers use Fermat's little theorem, so
// they require umod to be prime.
template <unsigned umod>
struct modint {
  static constexpr int mod = umod;
  unsigned v;
  modint() : v(0) {}
  // Reduce any integer into [0, mod). '%' keeps the sign of x, so a negative
  // remainder is shifted up by mod.
  template <class T, must_int<T> = 0>
  modint(T x) {
    x %= mod;
    v = x < 0 ? x + mod : x;
  }
  modint operator+() const { return *this; }
  modint operator-() const { return modint() - *this; }
  // The stored value, in [0, mod).
  friend int raw(const modint &self) { return self.v; }
  friend ostream &operator<<(ostream &os, const modint &self) {
    return os << raw(self);
  }
  modint &operator+=(const modint &rhs) {
    v += rhs.v;
    if (v >= umod) v -= umod;
    return *this;
  }
  modint &operator-=(const modint &rhs) {
    // On underflow the unsigned subtraction wraps past umod; add it back.
    v -= rhs.v;
    if (v >= umod) v += umod;
    return *this;
  }
  modint &operator*=(const modint &rhs) {
    v = 1ull * v * rhs.v % umod;
    return *this;
  }
  modint &operator/=(const modint &rhs) {
    assert(rhs.v);
    return *this *= qpow(rhs, mod - 2);
  }
  // a^b by binary exponentiation. A negative b means (a^{-1})^{-b}. Without
  // this branch, b >>= 1 would never reach 0 for negative b (infinite loop).
  template <class T, must_int<T> = 0>
  friend modint qpow(modint a, T b) {
    modint r = 1;
    if (b < 0) {
      assert(a.v);
      a = qpow(a, mod - 2);
      // Take one factor of a^{-1} now so that negating b cannot overflow
      // when b is the minimum value of T.
      r = a;
      b = -(b + 1);
    }
    for (; b; b >>= 1, a *= a)
      if (b & 1) r *= a;
    return r;
  }
  friend modint operator+(modint lhs, const modint &rhs) { return lhs += rhs; }
  friend modint operator-(modint lhs, const modint &rhs) { return lhs -= rhs; }
  friend modint operator*(modint lhs, const modint &rhs) { return lhs *= rhs; }
  friend modint operator/(modint lhs, const modint &rhs) { return lhs /= rhs; }
  // Friends (not members) so that an integer on the left, e.g. 1 == x,
  // also converts implicitly.
  friend bool operator==(const modint &lhs, const modint &rhs) {
    return lhs.v == rhs.v;
  }
  friend bool operator!=(const modint &lhs, const modint &rhs) {
    return lhs.v != rhs.v;
  }
};

dynamic_modint:把 template <unsigned umod> 改为 template <int id>,把 static constexpr int mod = umod; 改为 static int mod; static unsigned umod;(这两个静态成员需要在类外定义),并通过 setmod(p) 在运行时设置模数。

template <class T>
using must_int = enable_if_t<is_integral<T>::value, int>;
// Integers modulo a modulus chosen at run time with setmod(). Each id gets
// its own static modulus. v is always kept in [0, umod). Division and
// negative powers use Fermat's little theorem, so they require a prime mod.
template <int id>
struct modint {
  static int mod;
  static unsigned umod;
  // Must be called before any modint<id> is constructed: the constructor
  // computes x % mod, and mod starts out as 0.
  static void setmod(int p) {
    assert(p >= 1);
    mod = umod = p;
  }
  unsigned v;
  modint() : v(0) {}
  // Reduce any integer into [0, mod). '%' keeps the sign of x, so a negative
  // remainder is shifted up by mod.
  template <class T, must_int<T> = 0>
  modint(T x) {
    x %= mod;
    v = x < 0 ? x + mod : x;
  }
  modint operator+() const { return *this; }
  modint operator-() const { return modint() - *this; }
  // The stored value, in [0, mod).
  friend int raw(const modint &self) { return self.v; }
  friend ostream &operator<<(ostream &os, const modint &self) {
    return os << raw(self);
  }
  modint &operator+=(const modint &rhs) {
    v += rhs.v;
    if (v >= umod) v -= umod;
    return *this;
  }
  modint &operator-=(const modint &rhs) {
    // On underflow the unsigned subtraction wraps past umod; add it back.
    v -= rhs.v;
    if (v >= umod) v += umod;
    return *this;
  }
  modint &operator*=(const modint &rhs) {
    v = 1ull * v * rhs.v % umod;
    return *this;
  }
  modint &operator/=(const modint &rhs) {
    assert(rhs.v);
    return *this *= qpow(rhs, mod - 2);
  }
  // a^b by binary exponentiation. A negative b means (a^{-1})^{-b}. Without
  // this branch, b >>= 1 would never reach 0 for negative b (infinite loop).
  template <class T, must_int<T> = 0>
  friend modint qpow(modint a, T b) {
    modint r = 1;
    if (b < 0) {
      assert(a.v);
      a = qpow(a, mod - 2);
      // Take one factor of a^{-1} now so that negating b cannot overflow
      // when b is the minimum value of T.
      r = a;
      b = -(b + 1);
    }
    for (; b; b >>= 1, a *= a)
      if (b & 1) r *= a;
    return r;
  }
  friend modint operator+(modint lhs, const modint &rhs) { return lhs += rhs; }
  friend modint operator-(modint lhs, const modint &rhs) { return lhs -= rhs; }
  friend modint operator*(modint lhs, const modint &rhs) { return lhs *= rhs; }
  friend modint operator/(modint lhs, const modint &rhs) { return lhs /= rhs; }
  // Friends (not members) so that an integer on the left, e.g. 1 == x,
  // also converts implicitly.
  friend bool operator==(const modint &lhs, const modint &rhs) {
    return lhs.v == rhs.v;
  }
  friend bool operator!=(const modint &lhs, const modint &rhs) {
    return lhs.v != rhs.v;
  }
};
// Out-of-class definitions of the static members (zero until setmod()).
template <int id>
unsigned modint<id>::umod;
template <int id>
int modint<id>::mod;

设置模数例:

typedef modint<-1> mint;
// setmod is a static member, so it is called through the type (mint::, not
// mint.). Run it at the start of main(), before any mint is constructed.
mint::setmod(998244353);
dynamic_modint with Montgomery Reduction
template <class T>
using must_int = enable_if_t<is_integral<T>::value>*;
// Montgomery reduction with R = 2^32 for an odd modulus m < 2^31.
struct Montgomery {
  typedef unsigned u32;
  typedef unsigned long long u64;
  u32 umod;  // the modulus m
  u32 niv;   // -m^{-1} mod 2^32
  u32 R;     // 2^32 mod m (as u32, -m == 2^32 - m)
  u32 R2;    // 2^64 mod m
  Montgomery() {}
  Montgomery(u32 m) : umod(m), niv(2 + m), R(-m % m), R2(-(u64)m % m) {
    // Newton's iteration y <- y * (2 + m * y) solves m * y == -1 (mod 2^32).
    // The seed 2 + m is correct mod 4 since m * (2 + m) + 1 == (m + 1)^2, and
    // every step doubles the correct low bits: 2 -> 4 -> 8 -> 16 -> 32.
    for (int i = 0; i < 4; i++) niv *= 2 + m * niv;
  }
  // Returns x * 2^{-32} mod m, fully reduced into [0, m); needs x < m * 2^32.
  u32 reduce(u64 x) {
    x = (x + (u64)((u32)x * niv) * umod) >> 32;
    return x >= umod ? x - umod : x;
  }
};
// Run-time-modulus modint whose value is stored in Montgomery form, so that
// multiplication uses reduce() instead of a 64-bit '%'.
template <int id>
struct modint {
  static int mod;
  static unsigned umod;
  static Montgomery mont;
  // Sets the modulus for every modint<id>; call before creating any value.
  // p must be odd, otherwise -p^{-1} mod 2^32 does not exist and every
  // product would be silently wrong.
  static void setmod(int p) {
    assert(p >= 1 && p % 2 == 1);
    mod = umod = p;
    mont = Montgomery(p);
  }
  unsigned v;  // Montgomery form: (value * 2^32) mod umod, in [0, umod)
  modint() : v(0) {}
  // Into Montgomery form: reduce(x * R^2) == x * R (mod umod).
  modint(unsigned x) : v(mont.reduce(1ull * x * mont.R2)) {}
  template <class T, must_int<T> = nullptr>
  modint(T x) : modint((unsigned)(x %= mod, x < 0 ? x + mod : x)) {}
  modint operator+() const { return *this; }
  modint operator-() const { return modint() - *this; }
  // Out of Montgomery form: the ordinary value in [0, mod).
  friend int raw(const modint &self) { return mont.reduce(self.v); }
  friend ostream &operator<<(ostream &os, const modint &self) {
    return os << raw(self);
  }
  // Addition and subtraction are linear, so they work directly on the
  // Montgomery representation.
  modint &operator+=(const modint &rhs) {
    v += rhs.v;
    if (v >= umod) v -= umod;
    return *this;
  }
  modint &operator-=(const modint &rhs) {
    v -= rhs.v;
    if (v >= umod) v += umod;
    return *this;
  }
  // (aR)(bR)R^{-1} == (ab)R.
  modint &operator*=(const modint &rhs) {
    v = mont.reduce(1ull * v * rhs.v);
    return *this;
  }
  modint &operator/=(const modint &rhs) {
    assert(rhs.v);
    return *this *= qpow(rhs, mod - 2);
  }
  // a^b by binary exponentiation. A negative b means (a^{-1})^{-b}. Without
  // this branch, b >>= 1 would never reach 0 for negative b (infinite loop).
  template <class T, must_int<T> = nullptr>
  friend modint qpow(modint a, T b) {
    modint r = 1;
    if (b < 0) {
      assert(a.v);
      a = qpow(a, mod - 2);
      // Take one factor of a^{-1} now so that negating b cannot overflow
      // when b is the minimum value of T.
      r = a;
      b = -(b + 1);
    }
    for (; b; b >>= 1, a *= a)
      if (b & 1) r *= a;
    return r;
  }
  friend modint operator+(modint lhs, const modint &rhs) { return lhs += rhs; }
  friend modint operator-(modint lhs, const modint &rhs) { return lhs -= rhs; }
  friend modint operator*(modint lhs, const modint &rhs) { return lhs *= rhs; }
  friend modint operator/(modint lhs, const modint &rhs) { return lhs /= rhs; }
  // The Montgomery form is unique in [0, umod), so comparing v is exact.
  // Friends so that an integer on the left also converts implicitly.
  friend bool operator==(const modint &lhs, const modint &rhs) {
    return lhs.v == rhs.v;
  }
  friend bool operator!=(const modint &lhs, const modint &rhs) {
    return lhs.v != rhs.v;
  }
};
template <int id>
unsigned modint<id>::umod;
template <int id>
int modint<id>::mod;
template <int id>
Montgomery modint<id>::mont;
modint (clang-format on, version 2)
// Integers modulo the compile-time modulus P; v is kept in [0, P).
// inv(), '/' and negative powers rely on Fermat's little theorem (P prime).
template <unsigned P> struct modint {
    unsigned v; modint() : v(0) {}
    // (int)P keeps the modulus signed so a negative x is reduced correctly;
    // a negative remainder is then shifted up by P.
    template <class T> modint(T x) { x %= (int)P, v = x < 0 ? x + P : x; }
    modint operator+() const { return *this; }
    modint operator-() const { return modint(0) - *this; }
    modint inv() const { return assert(v), qpow(*this, P - 2); }
    friend int raw(const modint &self) { return self.v; }
    // a^b by binary exponentiation; a negative b means (a^{-1})^{-b}.
    // Without that branch, b >>= 1 never reaches 0 for negative b.
    template <class T> friend modint qpow(modint a, T b) {
        modint r = 1;
        // Take one inverse factor first so negating b cannot overflow at the minimum.
        if (b < 0) a = a.inv(), r = a, b = -(b + 1);
        for (; b; b >>= 1, a *= a) if (b & 1) r *= a;
        return r;
    }
    modint &operator+=(const modint &rhs) { if (v += rhs.v, v >= P) v -= P; return *this; }
    modint &operator-=(const modint &rhs) { if (v -= rhs.v, v >= P) v += P; return *this; }
    modint &operator*=(const modint &rhs) { v = 1ull * v * rhs.v % P; return *this; }
    modint &operator/=(const modint &rhs) { return *this *= rhs.inv(); }
    friend modint operator+(modint lhs, const modint &rhs) { return lhs += rhs; }
    friend modint operator-(modint lhs, const modint &rhs) { return lhs -= rhs; }
    friend modint operator*(modint lhs, const modint &rhs) { return lhs *= rhs; }
    friend modint operator/(modint lhs, const modint &rhs) { return lhs /= rhs; }
    friend bool operator==(const modint &lhs, const modint &rhs) { return lhs.v == rhs.v; }
    friend bool operator!=(const modint &lhs, const modint &rhs) { return lhs.v != rhs.v; }
};
modint by cjl
// Compact modint ("by cjl"). It relies on two names declared elsewhere:
//   mod - the modulus (int). '/' and negative powers assume it is prime (Fermat).
//   lll - a wide integer type, presumably a typedef for __int128 (verify).
struct m{
	int x;
	// Both constructors normalise into [0, mod). '%' keeps the sign of the
	// dividend, so a negative remainder is shifted up by mod. Without this,
	// m(-1) or m(mod) stored unreduced values and broke ==.
	m(int o=0){x=o%mod;if(x<0)x+=mod;}
	m(lll o){x=(int)(o%mod);if(x<0)x+=mod;}
	m&operator+=(m a){return(x+=a.x)%=mod,*this;}
	m&operator-=(m a){return(x+=mod-a.x)%=mod,*this;}
	m&operator*=(m a){return(x=1ll*x*a.x%mod),*this;}
	// In-place power by binary exponentiation. A negative b means (this^{-1})^{-b};
	// previously b>>=1 never reached 0 for negative b.
	m&operator^=(int b){
		m a=*this;x=1;
		// Take one inverse factor first so that negating b cannot overflow at INT_MIN.
		if(b<0)a^=mod-2,*this=a,b=-(b+1);
		for(;b;b>>=1,a*=a)if(b&1)*this*=a;
		return*this;
	}
	m&operator/=(m a){return*this*=(a^=mod-2);}
	friend m operator+(m a,m b){return a+=b;}
	friend m operator-(m a,m b){return a-=b;}
	friend m operator*(m a,m b){return a*=b;}
	friend m operator/(m a,m b){return a/=b;}
	friend m operator^(m a,int b){return a^=b;}
	m operator-()const{return 0-*this;}
	bool operator==(const m b)const{return x==b.x;}
};
仿照 atcoder library 写的:modint (version 1)
typedef long long LL;
// Integers modulo the compile-time modulus P; v is kept in [0, P).
// inv(), '/' and negative powers rely on Fermat's little theorem (P prime).
template<unsigned P> struct modint{
	unsigned v; modint():v(0){}
	// int(P) keeps the modulus signed, so a negative x is not first converted to
	// unsigned. The "+ P" is done in long long so it cannot overflow int when P > 2^30
	// (e.g. P = 2147483647).
	template<class T> modint(T x):v((x%int(P)+LL(P))%int(P)){}
	modint operator-()const{return modint(P-v);}
	modint inv()const{return assert(v),qpow(*this,LL(P)-2);}
	modint&operator+=(const modint&rhs){if(v+=rhs.v,v>=P) v-=P; return *this;}
	modint&operator-=(const modint&rhs){return *this+=-rhs;}
	modint&operator*=(const modint&rhs){v=1ull*v*rhs.v%P; return *this;}
	modint&operator/=(const modint&rhs){return *this*=rhs.inv();}
	friend int raw(const modint&self){return self.v;}
	// a^b by binary exponentiation; a negative b means (a^{-1})^{-b}.
	// Without that branch, b>>=1 never reaches 0 for negative b.
	friend modint qpow(modint a,LL b){
		modint r=1;
		if(b<0) a=a.inv(),r=a,b=-(b+1);  // take one inverse factor first so -b cannot overflow
		for(;b;b>>=1,a*=a) if(b&1) r*=a;
		return r;
	}
	friend modint operator+(modint lhs,const modint&rhs){return lhs+=rhs;}
	friend modint operator-(modint lhs,const modint&rhs){return lhs-=rhs;}
	friend modint operator*(modint lhs,const modint&rhs){return lhs*=rhs;}
	friend modint operator/(modint lhs,const modint&rhs){return lhs/=rhs;}
	friend bool operator==(const modint&lhs,const modint&rhs){return lhs.v==rhs.v;}
	friend bool operator!=(const modint&lhs,const modint&rhs){return lhs.v!=rhs.v;}
};

注意构造函数里的 int(P) 非常重要:P 是 unsigned,如果直接写 x%P,负数 x 会先被隐式转换成 unsigned(变成一个很大的正数)再取模,得到奇怪的结果。

精简版(类内用 using Z=modint; 缩短类型名),不过好像也并没有精简多少:
// Condensed modint: Z is a short alias for the class itself. Values are kept in [0, P);
// inv(), '/' and negative powers rely on Fermat's little theorem (P prime).
template<unsigned P> struct modint{
	unsigned v; using Z=modint; modint():v(0){}
	// int(P) keeps negative x signed; the "+ P" is done in long long so it cannot
	// overflow int when P > 2^30.
	template<class T> modint(T x):v((x%int(P)+(long long)P)%int(P)){}
	Z operator-()const{return Z(P-v);}
	Z inv()const{return qpow(*this,(long long)P-2);}
	Z&operator+=(Z b){if(v+=b.v,v>=P) v-=P; return *this;}
	Z&operator-=(Z b){return *this+=-b;}
	Z&operator*=(Z b){v=1ull*v*b.v%P; return *this;}
	Z&operator/=(Z b){return *this*=b.inv();}
	friend int raw(const Z&self){return self.v;}
	// a^b; a negative b means (a^{-1})^{-b}. Without that branch, b>>=1 never reaches 0.
	friend Z qpow(Z a,long long b){Z r=1;if(b<0) a=a.inv(),r=a,b=-(b+1);for(;b;b>>=1,a*=a) if(b&1) r*=a; return r;}
	friend Z operator+(Z a,Z b){return a+=b;}
	friend Z operator-(Z a,Z b){return a-=b;}
	friend Z operator*(Z a,Z b){return a*=b;}
	friend Z operator/(Z a,Z b){return a/=b;}
};
使用例
typedef modint<998244353> Z;
int main(){
//	#ifdef LOCAL
//	 	freopen("input.in","r",stdin);
//	#endif
	Z a=-1,b=-2;
	printf("%d %d\n",raw(a),raw(b));
	printf("%d\n",raw(a+b));
	printf("%d\n",raw(a-b));
	printf("%d\n",raw(a*b));
	printf("%d\n",raw(a/b));
	printf("%d\n",raw(qpow(a,2)));
	return 0;
}
C_prime (clang-format off)
using mint=modint<998244353>;
// Binomial coefficients C(n, m) modulo the prime 998244353, from factorial
// tables for 0..N built once in the constructor.
template<int N> struct C_prime{
    mint fac[N+10],ifac[N+10];  // fac[i] = i!, ifac[i] = 1 / i!
    C_prime(){
        // raw(fac[0]=1) both sets fac[0] and evaluates to 1, so i starts at 1.
        for(int i=raw(fac[0]=1);i<=N;i++) fac[i]=fac[i-1]*i;
        // One modular inverse, then 1/(i-1)! = i * (1/i!) going downwards.
        ifac[N]=1/fac[N];for(int i=N;i>=1;i--) ifac[i-1]=ifac[i]*i;
    }
    // C(n, m); 0 when m < 0 or m > n (a negative m used to index ifac[m] out of bounds).
    // Requires n <= N whenever 0 <= m <= n.
    mint operator()(int n,int m)const{return 0<=m&&m<=n?fac[n]*ifac[m]*ifac[n-m]:0;}
};
C_prime (clang-format on)
typedef modint<998244353> mint;
// Binomial coefficients C(n, m) modulo the prime 998244353, from factorial
// tables for 0..N built once in the constructor.
template <int N>
struct C_prime {
  mint fac[N + 10], ifac[N + 10];  // fac[i] = i!, ifac[i] = 1 / i!
  C_prime() {
    // raw(fac[0] = 1) both sets fac[0] and evaluates to 1, so i starts at 1.
    for (int i = raw(fac[0] = 1); i <= N; i++) fac[i] = fac[i - 1] * i;
    // One modular inverse, then 1/(i-1)! = i * (1/i!) going downwards.
    ifac[N] = 1 / fac[N];
    for (int i = N; i >= 1; i--) ifac[i - 1] = ifac[i] * i;
  }
  // C(n, m); 0 when m < 0 or m > n (a negative m used to index ifac[m] out
  // of bounds). Requires n <= N whenever 0 <= m <= n.
  mint operator()(int n, int m) const {
    return 0 <= m && m <= n ? fac[n] * ifac[m] * ifac[n - m] : 0;
  }
};

注意:modint 的值总是在 [0, mod) 内,无法用 -1 表示“无解/不存在”之类的特殊状态;需要用到 -1 这种标记时,就 throw 一个异常出去。

注意:这里的 int raw(modint) 用来返回 modint 里存的值(相当于 atcoder 的 val()),而 atcoder-library 的 raw(v) 恰好相反,是不取模直接构造一个 modint(之前一直没发现)。另外,这里没有提供不取模的构造函数,但 modint::v 是 public 的,可以直接赋值(Montgomery 版本除外,它的 v 存的是 Montgomery 形式)。也没有实现自增和自减。

本家(atcoder/modint.hpp)
#ifndef ATCODER_MODINT_HPP
#define ATCODER_MODINT_HPP 1

#include <cassert>
#include <numeric>
#include <type_traits>

#ifdef _MSC_VER
#include <intrin.h>
#endif

#include "atcoder/internal_math"
#include "atcoder/internal_type_traits"

namespace atcoder {

namespace internal {

// Empty tag bases: every modint type derives from modint_base, and the
// compile-time-modulus ones additionally derive from static_modint_base, so
// the traits below can detect them with std::is_base_of.
struct modint_base {};
struct static_modint_base : modint_base {};

// is_modint<T>: true when T derives from modint_base.
template <class T> using is_modint = std::is_base_of<modint_base, T>;
// SFINAE helper: well-formed (void) only when T is a modint type.
template <class T> using is_modint_t = std::enable_if_t<is_modint<T>::value>;

}  // namespace internal

// Integers modulo the compile-time constant m. The template is only enabled
// for 1 <= m. The value is stored reduced in _v, in [0, m).
template <int m, std::enable_if_t<(1 <= m)>* = nullptr>
struct static_modint : internal::static_modint_base {
    using mint = static_modint;

  public:
    static constexpr int mod() { return m; }
    // Builds a mint holding v as-is, without reducing modulo m; presumably
    // the caller must guarantee 0 <= v < m.
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }

    static_modint() : _v(0) {}
    // Signed integers: '%' keeps the sign of v, so a negative remainder is
    // shifted up by umod() to land in [0, m).
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    static_modint(T v) {
        long long x = (long long)(v % (long long)(umod()));
        if (x < 0) x += umod();
        _v = (unsigned int)(x);
    }
    // Unsigned integers: the remainder is already non-negative.
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    static_modint(T v) {
        _v = (unsigned int)(v % umod());
    }

    // The stored value, in [0, m).
    unsigned int val() const { return _v; }

    // Pre/post increment and decrement, wrapping around at m.
    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }

    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        // On underflow the unsigned subtraction wraps past umod(); add it back.
        _v -= rhs._v;
        if (_v >= umod()) _v += umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        // Widen to 64 bits so the product of two values below m cannot overflow.
        unsigned long long z = _v;
        z *= rhs._v;
        _v = (unsigned int)(z % umod());
        return *this;
    }
    // Multiplies by rhs.inv(); see inv() for its preconditions.
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }

    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }

    // this^n by binary exponentiation; n must be non-negative (asserted).
    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    // Multiplicative inverse. For prime m this uses Fermat's little theorem
    // (asserts _v != 0). Otherwise internal::inv_gcd is used: eg.first is
    // presumably gcd(_v, m), asserted to be 1, and eg.second the inverse
    // (see internal_math to confirm).
    mint inv() const {
        if (prime) {
            assert(_v);
            return pow(umod() - 2);
        } else {
            auto eg = internal::inv_gcd(_v, m);
            assert(eg.first == 1);
            return eg.second;
        }
    }

    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }

  private:
    unsigned int _v;
    static constexpr unsigned int umod() { return m; }
    // Whether m is prime, computed at compile time by internal::is_prime.
    static constexpr bool prime = internal::is_prime<m>;
};

// Integers modulo a modulus chosen at run time via set_mod(). Each distinct
// id has its own static modulus. Multiplication goes through
// internal::barrett (presumably Barrett reduction).
template <int id> struct dynamic_modint : internal::modint_base {
    using mint = dynamic_modint;

  public:
    static int mod() { return (int)(bt.umod()); }
    // Changes the modulus for every dynamic_modint<id>. Existing values are
    // not re-reduced, so set it before creating any.
    static void set_mod(int m) {
        assert(1 <= m);
        bt = internal::barrett(m);
    }
    // Builds a mint holding v as-is, without reduction; presumably the
    // caller must guarantee 0 <= v < mod().
    static mint raw(int v) {
        mint x;
        x._v = v;
        return x;
    }

    dynamic_modint() : _v(0) {}
    // Signed integers: a negative remainder is shifted up by mod().
    template <class T, internal::is_signed_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        long long x = (long long)(v % (long long)(mod()));
        if (x < 0) x += mod();
        _v = (unsigned int)(x);
    }
    // Unsigned integers: the remainder is already non-negative.
    template <class T, internal::is_unsigned_int_t<T>* = nullptr>
    dynamic_modint(T v) {
        _v = (unsigned int)(v % mod());
    }

    // The stored value, in [0, mod()).
    unsigned int val() const { return _v; }

    // Pre/post increment and decrement, wrapping around at mod().
    mint& operator++() {
        _v++;
        if (_v == umod()) _v = 0;
        return *this;
    }
    mint& operator--() {
        if (_v == 0) _v = umod();
        _v--;
        return *this;
    }
    mint operator++(int) {
        mint result = *this;
        ++*this;
        return result;
    }
    mint operator--(int) {
        mint result = *this;
        --*this;
        return result;
    }

    mint& operator+=(const mint& rhs) {
        _v += rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator-=(const mint& rhs) {
        // Adds mod() - rhs._v instead of subtracting, then reduces once.
        _v += mod() - rhs._v;
        if (_v >= umod()) _v -= umod();
        return *this;
    }
    mint& operator*=(const mint& rhs) {
        _v = bt.mul(_v, rhs._v);
        return *this;
    }
    // Multiplies by rhs.inv(); see inv() for its preconditions.
    mint& operator/=(const mint& rhs) { return *this = *this * rhs.inv(); }

    mint operator+() const { return *this; }
    mint operator-() const { return mint() - *this; }

    // this^n by binary exponentiation; n must be non-negative (asserted).
    mint pow(long long n) const {
        assert(0 <= n);
        mint x = *this, r = 1;
        while (n) {
            if (n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }
    // Inverse via internal::inv_gcd, so mod() need not be prime. eg.first is
    // presumably gcd(_v, mod()), asserted to be 1, and eg.second the inverse.
    mint inv() const {
        auto eg = internal::inv_gcd(_v, mod());
        assert(eg.first == 1);
        return eg.second;
    }

    friend mint operator+(const mint& lhs, const mint& rhs) {
        return mint(lhs) += rhs;
    }
    friend mint operator-(const mint& lhs, const mint& rhs) {
        return mint(lhs) -= rhs;
    }
    friend mint operator*(const mint& lhs, const mint& rhs) {
        return mint(lhs) *= rhs;
    }
    friend mint operator/(const mint& lhs, const mint& rhs) {
        return mint(lhs) /= rhs;
    }
    friend bool operator==(const mint& lhs, const mint& rhs) {
        return lhs._v == rhs._v;
    }
    friend bool operator!=(const mint& lhs, const mint& rhs) {
        return lhs._v != rhs._v;
    }

  private:
    unsigned int _v;
    static internal::barrett bt;
    static unsigned int umod() { return bt.umod(); }
};
// Initial modulus for every id, until set_mod() replaces it.
template <int id> internal::barrett dynamic_modint<id>::bt(998244353);

// Convenience aliases: the two common prime moduli as static_modint, and
// `modint` as the run-time-modulus type with id -1.
using modint998244353 = static_modint<998244353>;
using modint1000000007 = static_modint<1000000007>;
using modint = dynamic_modint<-1>;

namespace internal {

// is_static_modint<T>: true when T derives from static_modint_base.
template <class T>
using is_static_modint = std::is_base_of<internal::static_modint_base, T>;

// SFINAE helper: well-formed (void) only for static_modint types.
template <class T>
using is_static_modint_t = std::enable_if_t<is_static_modint<T>::value>;

// is_dynamic_modint<T>: true exactly for dynamic_modint<id> specializations.
template <class> struct is_dynamic_modint : public std::false_type {};
template <int id>
struct is_dynamic_modint<dynamic_modint<id>> : public std::true_type {};

// SFINAE helper: well-formed (void) only for dynamic_modint types.
template <class T>
using is_dynamic_modint_t = std::enable_if_t<is_dynamic_modint<T>::value>;

}  // namespace internal

}  // namespace atcoder

#endif  // ATCODER_MODINT_HPP

posted @ 2023-08-12 16:32  caijianhong  阅读(386)  评论(0)  编辑  收藏  举报