Montgomery Reduction算法流程与实际实现简介

Montgomery Reduction 算法流程与实际实现

下面默认对于模数\(m\)取模,由于这篇文章的重点是实现(其实就是我自己存一下板子),因此没有证明

使用注意:

Montgomery Reduction 相较于 Barret Reduction来说,不需要使用__int128

但是有着更高的封装程度,因为涉及到普通数与Montgomery Reduction运算中间量的转化

另外,常见的Montgomery Reduction 在编程竞赛中的应用 要求模数为奇数

但是在Min25博客上来看,Montgomery似乎有着更高的效率

Montgomery Reduction算法思想简介

在计算取模运算的过程中,将每一个元素\(T\)都乘上一个特定的值\(R(R>m,\gcd(R,m)=1)\)

用特殊的方法处理相乘时除掉一个\(R\)的过程,从而避免取模运算

在使用的模数为常量时,编译器通常会自动加入Barrett reduction的优化,因此实际上这个算法对于动态模数的情形更为适用

(你自己真不一定写得过STL,但是确实可以比STL块)

\[\ \]

编程上的应用简介

对于\(m\)为奇数的情况,取\(R=2^{32}\),用 自然溢出来代替取模/位运算位移代替除法 来加速运算

我们还需要令\(m' = -m^{-1} \mod R\),有结论

对于某一个数\(T,0 \leq T < mR\),若令\(U = Tm’ \mod R\),则 \(\frac{T+Um}{R}\)为整数,且 \(\frac{T+Um}{R}=TR^{-1} \mod m\)

那么我们在计算\(\frac{T}{R}\)时,实际上只需要计算\(\frac{T+Um}{R}\),可以预处理\(m'\),溢出计算\(Tm'\),位运算左移计算\(\frac{T+Um}{R}\)

实际使用时的实现,可以用一个类实现以下方法

在实现时需要尤其注意不要出现溢出

1.预处理\(m'\)

\((R-\lfloor \frac{R}{m}\rfloor )\cdot (R\mod m)\)

using u32=unsigned;
using i32=int;
using u64=unsigned long long;
using i64=long long;
// inv=m'
u32 m;
u32 getinv(){
	u32 inv=m;
	for(int i=0;i<4;++i) inv*=2-inv*m;
}

2.reduce方法

u32 reduce(u64 x) {
    u32 y = u32(x >> 32) - u32((u64(u32(x)*inv)*m) >> 32);
    // 先取u32(x)得到x mod R ,然后再转成u64进行乘法
    return i32(y) < 0 ? y + m : y;
}

3.普通数转Montgomery Reduction

我们要计算\(x\rightarrow xR=x\cdot 2^{32}\),但是如果直接用取模就失去了意义。。。

方法是快速计算\(x\cdot R^2\),然后reduce一次

u32 R2=-u64(m)%m;
u32 intToMont(i32 x){
    return reduce(u64(x)*R2);
}

\[\ \]

4.Montomery运算

u32 Add(u32 x,u32 y) {
    x+=y-m;
    return i32(x<0)?x+m:x;
}
u32 Dec(u32 x,u32 y){
    x-=y;
    return i32(x<0)?x+m:x;
}
u32 Mul(u32 x,u32 y){
    return reduce(u64(x)*y);
}

\[\ \]

5.Montomery Reduction转普通数

i32 get(u32 x){
    return reduce(x);
}

封装之后,得到板子一号,这个是动态模数的。。。

实现上可能的误区:

为什么不用-inv?避免加法,原因是加法取模要和m比较

同样的,下面的i32(y)<0语句可以被替换为y>=m(负数溢出),看似减少一次类型转换,但是实际上0作为常量比较快得多

加法运算时也是类似的原因,x>=m的比较实在太慢,因此强制减去一个m,然后和0比

using u32=uint32_t;
using i32=int32_t;
using u64=uint64_t;
using i64=int64_t;

static u32 m,inv,r2,P;
u32 getinv(){
    u32 inv=m;
    for(int i=0;i<4;++i) inv*=2-inv*m;
    return inv;
}
struct Mont{
private :
    u32 x;
public :
    static u32 reduce(u64 x){ 
        u32 y=u32(x>>32)-u32((u64(u32(x)*inv)*m)>>32);
        return i32(y)<0?y+m:y;
    }
    Mont(){ ; }
    Mont(i32 x):x(reduce(u64(x)*r2)) { }
    Mont& operator += (const Mont &rhs) { return x+=rhs.x-m,is32(x)<0&&(x+=m),*this; }
    Mont& operator -= (const Mont &rhs) { return x-=rhs.x,i32(x)<0&&(x+=m),*this; }
    Mont& operator *= (const Mont &rhs) { return x=reduce(u64(x)*rhs.x),*this; }
    friend Mont operator + (Mont x,const Mont &y) { return x+=y; }
    friend Mont operator - (Mont x,const Mont &y) { return x-=y; }
    friend Mont operator * (Mont x,const Mont &y) { return x*=y; }
    i32 get(){ return reduce(x); }
};
void Init(int m) { 
    ::m=m;
    inv=-getinv();
    r2=-u64(m)%m;
}

动态模数的方法,计算\(5\cdot 10^7!\mod 998244353\)在duck.ac上评测结果,时间单位是微秒\(\mu s\)

Naive Mod     : 213689172  Time: 518352
My Montgomery : 213689172  Time: 192195

\[\ \]

\[\ \]

这个是我自己写的静态模数的,因为模数是静态的,所以不需要一定和0比较大小

template <uint32_t m> struct Mont{
private :
    using u32=uint32_t;
    using i32=int32_t;
    using u64=uint64_t;
    using i64=int64_t;
    static constexpr u32 getinv(){
        u32 inv=m;
        for(int i=0;i<4;++i) inv*=2-inv*m;
        return inv;
    }
    static constexpr u32 inv=-getinv(),r2=-u64(m)%m;
    u32 x;
public :
    static constexpr u32 reduce(u64 x){ 
        u32 y=(x+u64(u32(x)*inv)*m)>>32;
        return y>=m?y-m:y;
    }
    Mont(){ ; }
    constexpr Mont(i32 x):x(reduce(u64(x)*r2)) { }
    constexpr Mont& operator += (const Mont &rhs) { return x+=rhs.x-m,x>=m&&(x+=m),*this; }
    constexpr Mont& operator -= (const Mont &rhs) { return x-=rhs.x,x>=m&&(x+=m),*this; }
    constexpr Mont& operator *= (const Mont &rhs) { return x=reduce(u64(x)*rhs.x),*this; }
    constexpr friend Mont operator + (Mont x,const Mont &y) { return x+=y; }
    constexpr friend Mont operator - (Mont x,const Mont &y) { return x-=y; }
    constexpr friend Mont operator * (Mont x,const Mont &y) { return x*=y; }
    constexpr i32 get(){ return reduce(x); }
} ;

这个是摘自LOJ多项式乘法 hly1204的提交记录

个人解读:实际上每次存储的是\(x \mod 2m\)的值,避免了reduce时的加减取模

// from https://min-25.hatenablog.com/entry/2017/08/20/171214
template <std::uint32_t P> struct MontgomeryModInt32 {
public:
  using i32 = std::int32_t;
  using u32 = std::uint32_t;
  using i64 = std::int64_t;
  using u64 = std::uint64_t;

private:
  u32 v;

  static constexpr u32 get_r() {
    u32 iv = P;
    for (u32 i = 0; i != 4; ++i) iv *= 2 - P * iv;
    return iv;
  }

  static constexpr u32 r = -get_r(), r2 = -u64(P) % P;

  static_assert((P & 1) == 1);
  static_assert(r * P == -1);
  static_assert(P < (1 << 30));

public:
  static constexpr u32 pow_mod(u32 x, u64 y) {
    if ((y %= P - 1) < 0) y += P - 1;
    u32 res = 1;
    for (; y != 0; y >>= 1, x = u64(x) * x % P)
      if (y & 1) res = u64(res) * x % P;
    return res;
  }

  static constexpr u32 get_pr() {
    u32 tmp[32] = {}, cnt = 0;
    const u64 phi = P - 1;
    u64 m = phi;
    for (u64 i = 2; i * i <= m; ++i) {
      if (m % i == 0) {
        tmp[cnt++] = i;
        while (m % i == 0) m /= i;
      }
    }
    if (m > 1) tmp[cnt++] = m;
    for (u64 res = 2; res <= phi; ++res) {
      bool flag = true;
      for (u32 i = 0; i != cnt && flag; ++i) flag &= pow_mod(res, phi / tmp[i]) != 1;
      if (flag) return res;
    }
    return 0;
  }

  MontgomeryModInt32() = default;
  ~MontgomeryModInt32() = default;
  constexpr MontgomeryModInt32(u32 v) : v(reduce(u64(v) * r2)) {}
  constexpr MontgomeryModInt32(const MontgomeryModInt32 &rhs) : v(rhs.v) {}
  static constexpr u32 reduce(u64 x) { return x + (u64(u32(x) * r) * P) >> 32; }
  constexpr u32 get() const {
    u32 res = reduce(v);
    return res - (P & -(res >= P));
  }
  explicit constexpr operator u32() const { return get(); }
  explicit constexpr operator i32() const { return i32(get()); }
  constexpr MontgomeryModInt32 &operator=(const MontgomeryModInt32 &rhs) {
    return v = rhs.v, *this;
  }
  constexpr MontgomeryModInt32 operator-() const {
    MontgomeryModInt32 res;
    return res.v = (P << 1 & -(v != 0)) - v, res;
  }
  constexpr MontgomeryModInt32 inv() const { return pow(-1); }
  constexpr MontgomeryModInt32 &operator+=(const MontgomeryModInt32 &rhs) {
    return v += rhs.v - (P << 1), v += P << 1 & -(i32(v) < 0), *this;
  }
  constexpr MontgomeryModInt32 &operator-=(const MontgomeryModInt32 &rhs) {
    return v -= rhs.v, v += P << 1 & -(i32(v) < 0), *this;
  }
  constexpr MontgomeryModInt32 &operator*=(const MontgomeryModInt32 &rhs) {
    return v = reduce(u64(v) * rhs.v), *this;
  }
  constexpr MontgomeryModInt32 &operator/=(const MontgomeryModInt32 &rhs) {
    return this->operator*=(rhs.inv());
  }
  friend MontgomeryModInt32 operator+(const MontgomeryModInt32 &lhs,
                                      const MontgomeryModInt32 &rhs) {
    return MontgomeryModInt32(lhs) += rhs;
  }
  friend MontgomeryModInt32 operator-(const MontgomeryModInt32 &lhs,
                                      const MontgomeryModInt32 &rhs) {
    return MontgomeryModInt32(lhs) -= rhs;
  }
  friend MontgomeryModInt32 operator*(const MontgomeryModInt32 &lhs,
                                      const MontgomeryModInt32 &rhs) {
    return MontgomeryModInt32(lhs) *= rhs;
  }
  friend MontgomeryModInt32 operator/(const MontgomeryModInt32 &lhs,
                                      const MontgomeryModInt32 &rhs) {
    return MontgomeryModInt32(lhs) /= rhs;
  }
  friend std::istream &operator>>(std::istream &is, MontgomeryModInt32 &rhs) {
    return is >> rhs.v, rhs.v = reduce(u64(rhs.v) * r2), is;
  }
  friend std::ostream &operator<<(std::ostream &os, const MontgomeryModInt32 &rhs) {
    return os << rhs.get();
  }
  constexpr MontgomeryModInt32 pow(i64 y) const {
    if ((y %= P - 1) < 0) y += P - 1; // phi(P) = P - 1, assume P is a prime number
    MontgomeryModInt32 res(1), x(*this);
    for (; y != 0; y >>= 1, x *= x)
      if (y & 1) res *= x;
    return res;
  }
};

这个是计算\(5\cdot 10^7!\mod 998244353\)在duck.ac上的测试结果

Naive Mod      : 213689172  Time: 180649
My Montgomery  : 213689172  Time: 178217
His Montgomery : 213689172  Time: 152847

这个是计算\(7\cdot 10^7!\mod 998244353\)在duck.ac上的测试结果

Naive Mod      : 939830261  Time: 252908
My Montgomery  : 939830261  Time: 249476
His Montgomery : 939830261  Time: 213986

还可以看Min25博客里下面的ModInt64板本

传送门

下面自己实现的\(\mod 2m\)版本,差不多也是最终版本了,跑起来和hly1204差不多

静态版本

template <uint32_t m> struct Mont2{
private :
    using u32=uint32_t;
    using i32=int32_t;
    using u64=uint64_t;
    using i64=int64_t;
    static constexpr u32 m2=m<<1;
    static constexpr u32 getinv(){
        u32 inv=m;
        for(int i=0;i<4;++i) inv*=2-inv*m;
        return inv;
    }
    static constexpr u32 inv=-getinv(),r2=-u64(m)%m;
    u32 x;
public :
    static constexpr u32 reduce(u64 x){ 
        return (x+u64(u32(x)*inv)*m)>>32;
    }
    Mont2(){ ; }
    constexpr Mont2(i32 x):x(reduce(u64(x)*r2)) { }
    constexpr Mont2& operator += (const Mont2 &rhs) { return x+=rhs.x-m2,x>=m2&&(x+=m2),*this; }
    constexpr Mont2& operator -= (const Mont2 &rhs) { return x-=rhs.x,x>=m2&&(x+=m2),*this; }
    constexpr Mont2& operator *= (const Mont2 &rhs) { return x=reduce(u64(x)*rhs.x),*this; }
    constexpr friend Mont2 operator + (Mont2 x,const Mont2 &y) { return x+=y; }
    constexpr friend Mont2 operator - (Mont2 x,const Mont2 &y) { return x-=y; }
    constexpr friend Mont2 operator * (Mont2 x,const Mont2 &y) { return x*=y; }
    constexpr i32 get(){ 
        u32 res=reduce(x); 
        return res>=m?res-m:res;
    }
} ;

板子各有优劣.jpg

另外这是Int_To_Montgomery加法的速度,\(7\cdot 10^7\)次加法与类型转换

Naive :        : 305907824 80074
My Montgomery  : 305907824 109479
My Montgomery2 : 305907824 99896
His Montgomery : 305907824 117449

动态版本

using u32=uint32_t;
using i32=int32_t;
using u64=uint64_t;
using i64=int64_t;

static u32 m,m2,inv,r2,P;
u32 getinv(){
    u32 inv=m;
    for(int i=0;i<4;++i) inv*=2-inv*m;
    return inv;
}
struct Mont{
private :
    u32 x;
public :
    static u32 reduce(u64 x){ 
        u32 y=(x+u64(u32(x)*inv)*m)>>32;
        return i32(y)<0?y+m:y;
    }
    Mont(){ ; }
    Mont(i32 x):x(reduce(u64(x)*r2)) { }
    Mont& operator += (const Mont &rhs) { return x+=rhs.x-m2,i32(x)<0&&(x+=m2),*this; }
    Mont& operator -= (const Mont &rhs) { return x-=rhs.x,i32(x)<0&&(x+=m2),*this; }
    Mont& operator *= (const Mont &rhs) { return x=reduce(u64(x)*rhs.x),*this; }
    friend Mont operator + (Mont x,const Mont &y) { return x+=y; }
    friend Mont operator - (Mont x,const Mont &y) { return x-=y; }
    friend Mont operator * (Mont x,const Mont &y) { return x*=y; }
    i32 get(){ 
        u32 res=reduce(x);
        return res>=m?res-m:res;
    }
};
void Init(int m) { 
    ::m=m,m2=m*2;
    inv=-getinv();
    r2=-u64(m)%m;
}

这个动态模板计算\(5\cdot 10^7!\mod 998244353\)

Naive Mod      : 213689172 494061 (稍微修改了一下暴力的细节。。)
My Montgomery2 : 213689172 152849

不得不说duck.ac真的很nb

posted @ 2020-12-01 19:56  chasedeath  阅读(1461)  评论(0编辑  收藏  举报