Montgomery Reduction算法流程与实际实现简介
Montgomery Reduction 算法流程与实际实现
下面默认对于模数\(m\)取模,由于这篇文章的重点是实现(其实就是我自己存一下板子),因此没有证明
使用注意:
Montgomery Reduction 相较于 Barrett Reduction 来说,不需要使用__int128
但是有着更高的封装程度,因为涉及到普通数与Montgomery Reduction运算中间量的转化
另外,常见的Montgomery Reduction 在编程竞赛中的应用 要求模数为奇数
但是在Min25博客上来看,Montgomery似乎有着更高的效率
Montgomery Reduction算法思想简介
在计算取模运算的过程中,将每一个元素\(T\)都乘上一个特定的值\(R(R>m,\gcd(R,m)=1)\)
用特殊的方法处理相乘时除掉一个\(R\)的过程,从而避免取模运算
在使用的模数为常量时,编译器通常会自动加入Barrett reduction的优化,因此实际上这个算法对于动态模数的情形更为适用
(你自己真不一定写得过STL,但是确实可以比STL快)
编程上的应用简介
对于\(m\)为奇数的情况,取\(R=2^{32}\),用 自然溢出来代替取模/位运算位移代替除法 来加速运算
我们还需要令\(m' = -m^{-1} \mod R\),有结论
对于某一个数\(T,0 \leq T < mR\),若令\(U = Tm' \bmod R\),则 \(\frac{T+Um}{R}\)为整数,且 \(\frac{T+Um}{R}\equiv TR^{-1} \pmod m\)
那么我们在计算\(TR^{-1} \bmod m\)(即模意义下的\(\frac{T}{R}\))时,实际上只需要计算\(\frac{T+Um}{R}\),可以预处理\(m'\),溢出计算\(Tm'\),位运算右移计算\(\frac{T+Um}{R}\)
实际使用时的实现,可以用一个类实现以下方法
在实现时需要尤其注意不要出现溢出
1.预处理\(m'\)
\((R-\lfloor \frac{R}{m}\rfloor )\cdot (R\mod m)\)
using u32=unsigned;
using i32=int;
using u64=unsigned long long;
using i64=long long;
// Modulus. It has to be odd, otherwise it has no inverse modulo 2^32.
u32 m;
// Returns the multiplicative inverse of m modulo 2^32, i.e. a value `inv`
// with inv * m == 1 in wrapping 32-bit arithmetic.
// The Montgomery constant m' = -m^{-1} mod 2^32 is the negation of this value.
u32 getinv(){
    // For odd m, m * m == 1 (mod 8), so m itself is already correct to 3 bits.
    u32 inv=m;
    // Each Newton step doubles the number of correct low bits: 3 -> 6 -> 12 -> 24 -> 48.
    for(int i=0;i<4;++i) inv*=2-inv*m;
    return inv;
}
2.reduce方法
// Montgomery reduction in its subtraction form.
// Splits x into high and low 32-bit halves, builds the multiple of m whose low
// half matches the low half of x, and subtracts the two high halves.
// NOTE(review): assumes inv * m == 1 (mod 2^32) and m < 2^31, so that a
// wrapped (negative) difference is visible in the sign bit -- confirm at the
// definition of inv.
u32 reduce(u64 x) {
    const u32 high = u32(x >> 32);
    // Take x mod 2^32 first, multiply in 32 bits, then widen for the product with m.
    const u64 qm = u64(u32(x)*inv)*m;
    u32 y = high - u32(qm >> 32);
    if (i32(y) < 0) y += m;
    return y;
}
3.普通数转Montgomery Reduction
我们要计算\(x\rightarrow xR=x\cdot 2^{32}\),但是如果直接用取模就失去了意义。。。
方法是快速计算\(x\cdot R^2\),然后reduce一次
// 2^64 mod m: 0 - u64(m) wraps to 2^64 - m, which is congruent to 2^64 modulo m.
u32 R2=(0-u64(m))%m;
// Converts a plain integer into Montgomery form by multiplying with R2 and
// reducing once.
// NOTE(review): u64(x) sign-extends a negative x, so callers are presumably
// expected to pass non-negative values -- confirm.
u32 intToMont(i32 x){
    const u64 scaled = u64(x)*R2;
    return reduce(scaled);
}
4.Montgomery运算
// Modular addition: for x, y in [0, m) returns (x + y) mod m in [0, m).
// Requires m < 2^31 so that a wrapped intermediate shows up in the sign bit.
u32 Add(u32 x,u32 y) {
    // x + y - m; wraps around (becomes "negative") exactly when x + y < m.
    x+=y-m;
    // Cast the value, not the comparison: x is unsigned, so `x<0` is always false.
    return i32(x)<0?x+m:x;
}
// Modular subtraction: for x, y in [0, m) returns (x - y) mod m in [0, m).
// Requires m < 2^31 so that a wrapped intermediate shows up in the sign bit.
u32 Dec(u32 x,u32 y){
    // Wraps around (becomes "negative") exactly when x < y.
    x-=y;
    // Cast the value, not the comparison: x is unsigned, so `x<0` is always false.
    return i32(x)<0?x+m:x;
}
// Montgomery multiplication: widens to 64 bits for the full product and
// hands it to reduce().
u32 Mul(u32 x,u32 y){
    const u64 product = u64(x)*y;
    return reduce(product);
}
5.Montgomery Reduction转普通数
// Converts a Montgomery-form value back to a plain integer with a single
// reduce() call on the zero-extended value.
i32 get(u32 x){
    const u64 widened = x;
    return i32(reduce(widened));
}
封装之后,得到板子一号,这个是动态模数的。。。
实现上可能的误区:
为什么不用-inv?避免加法,原因是加法取模要和m比较
同样的,下面的i32(y)<0语句可以被替换为y>=m(负数溢出),看似减少一次类型转换,但是实际上0作为常量比较快得多
加法运算时也是类似的原因,x>=m的比较实在太慢,因此强制减去一个m,然后和0比
using u32=uint32_t;
using i32=int32_t;
using u64=uint64_t;
using i64=int64_t;
// Run-time Montgomery parameters, filled in by Init():
//   m   - the modulus (must be odd; the sign-bit tricks below need m < 2^31)
//   inv - m^{-1} mod 2^32, as required by the subtraction form of reduce()
//   r2  - 2^64 mod m, used to convert plain integers into Montgomery form
//   P   - not referenced by this snippet
static u32 m,inv,r2,P;
// Returns the inverse of the global m modulo 2^32 via Newton iteration.
// m is correct to 3 bits for odd m; four steps give 48 >= 32 correct bits.
u32 getinv(){
    u32 inv=m;
    for(int i=0;i<4;++i) inv*=2-inv*m;
    return inv;
}
// Integer modulo the run-time modulus m, stored in Montgomery form in [0, m).
// Init() must be called before any Mont is constructed.
struct Mont{
private :
    u32 x;
public :
    // Montgomery reduction (subtraction form): for 0 <= x < m * 2^32 returns
    // x * 2^-32 mod m. The difference of the high halves lies in (-m, m), so a
    // single conditional add of m normalises it.
    static u32 reduce(u64 x){
        u32 y=u32(x>>32)-u32((u64(u32(x)*inv)*m)>>32);
        return i32(y)<0?y+m:y;
    }
    // Leaves x uninitialised.
    Mont(){ ; }
    // Converts a plain integer: x * r2 reduced once gives x * 2^32 mod m.
    Mont(i32 x):x(reduce(u64(x)*r2)) { }
    // x + rhs - m, then add m back if that wrapped below zero.
    Mont& operator += (const Mont &rhs) { return x+=rhs.x-m,i32(x)<0&&(x+=m),*this; }
    // x - rhs, then add m back if that wrapped below zero.
    Mont& operator -= (const Mont &rhs) { return x-=rhs.x,i32(x)<0&&(x+=m),*this; }
    Mont& operator *= (const Mont &rhs) { return x=reduce(u64(x)*rhs.x),*this; }
    friend Mont operator + (Mont x,const Mont &y) { return x+=y; }
    friend Mont operator - (Mont x,const Mont &y) { return x-=y; }
    friend Mont operator * (Mont x,const Mont &y) { return x*=y; }
    // Converts back from Montgomery form to a plain value in [0, m).
    i32 get() const { return reduce(x); }
};
// Sets the modulus and derives the constants used by Mont.
void Init(int m) {
    ::m=m;
    // reduce() subtracts, so it needs +m^{-1}; the negated value belongs to the
    // additive form of the reduction and would give wrong results here.
    inv=getinv();
    // -u64(m) is 2^64 - m, congruent to 2^64 modulo m.
    r2=-u64(m)%m;
}
动态模数的方法,计算\(5\cdot 10^7!\mod 998244353\)在duck.ac上评测结果,时间单位是微秒\(\mu s\)
Naive Mod : 213689172 Time: 518352
My Montgomery : 213689172 Time: 192195
这个是我自己写的静态模数的,因为模数是静态的,所以不需要一定和0比较大小
// Integer modulo the compile-time modulus m, stored in Montgomery form and
// kept fully reduced in [0, m). m must be odd; the wrap-detection in += and -=
// relies on m < 2^31.
template <uint32_t m> struct Mont{
private :
    using u32=uint32_t;
    using i32=int32_t;
    using u64=uint64_t;
    using i64=int64_t;
    // Newton iteration for the inverse of m modulo 2^32.
    static constexpr u32 getinv(){
        u32 cur=m;
        int rounds=4;
        while(rounds-->0) cur*=2-cur*m;
        return cur;
    }
    // inv = -m^{-1} mod 2^32 (additive reduction), r2 = 2^64 mod m.
    static constexpr u32 inv=-getinv(),r2=-u64(m)%m;
    u32 x;
public :
    // Additive Montgomery reduction followed by one conditional subtraction.
    static constexpr u32 reduce(u64 x){
        const u64 qm=u64(u32(x)*inv)*m;
        const u32 y=(x+qm)>>32;
        if(y>=m) return y-m;
        return y;
    }
    // Leaves x uninitialised.
    Mont(){ ; }
    // Plain integer -> Montgomery form.
    constexpr Mont(i32 x):x(reduce(u64(x)*r2)) { }
    // After x + rhs - m the value is either already in [0, m) or has wrapped
    // to the top of the u32 range; in the latter case it compares >= m.
    constexpr Mont& operator += (const Mont &rhs) {
        x+=rhs.x-m;
        if(x>=m) x+=m;
        return *this;
    }
    constexpr Mont& operator -= (const Mont &rhs) {
        x-=rhs.x;
        if(x>=m) x+=m;
        return *this;
    }
    constexpr Mont& operator *= (const Mont &rhs) {
        x=reduce(u64(x)*rhs.x);
        return *this;
    }
    constexpr friend Mont operator + (Mont x,const Mont &y) { x+=y; return x; }
    constexpr friend Mont operator - (Mont x,const Mont &y) { x-=y; return x; }
    constexpr friend Mont operator * (Mont x,const Mont &y) { x*=y; return x; }
    // Montgomery form -> plain integer.
    constexpr i32 get(){
        return reduce(x);
    }
} ;
个人解读:实际上每次存储的是\(x \mod 2m\)的值,避免了reduce时的加减取模
// from https://min-25.hatenablog.com/entry/2017/08/20/171214
// Modular integer in Montgomery form for a compile-time modulus P.
// P must be odd and below 2^30 (both checked by static_assert below).
// The stored value v is kept modulo 2P rather than fully reduced: reduce()
// has no final subtraction, += / -= correct by P << 1, and only get()
// subtracts the last P.
template <std::uint32_t P> struct MontgomeryModInt32 {
public:
using i32 = std::int32_t;
using u32 = std::uint32_t;
using i64 = std::int64_t;
using u64 = std::uint64_t;
private:
// Value in Montgomery form.
u32 v;
// Newton iteration; the static_assert on r below confirms that the result
// satisfies get_r() * P == 1 modulo 2^32.
static constexpr u32 get_r() {
u32 iv = P;
for (u32 i = 0; i != 4; ++i) iv *= 2 - P * iv;
return iv;
}
// r  = -P^{-1} mod 2^32.
// r2 = 2^64 mod P (-u64(P) wraps to 2^64 - P).
static constexpr u32 r = -get_r(), r2 = -u64(P) % P;
static_assert((P & 1) == 1);
static_assert(r * P == -1);
static_assert(P < (1 << 30));
public:
// Plain (non-Montgomery) modular exponentiation x^y mod P using the % operator.
// The exponent is first reduced modulo P - 1; y is unsigned, so the `< 0`
// branch can never be taken.
static constexpr u32 pow_mod(u32 x, u64 y) {
if ((y %= P - 1) < 0) y += P - 1;
u32 res = 1;
for (; y != 0; y >>= 1, x = u64(x) * x % P)
if (y & 1) res = u64(res) * x % P;
return res;
}
// Collects the distinct prime factors of phi = P - 1 by trial division, then
// returns the smallest res >= 2 with res^(phi / q) != 1 (mod P) for every such
// factor q, or 0 if none is found.
// NOTE(review): this is a primitive root only if P is prime -- confirm callers.
static constexpr u32 get_pr() {
u32 tmp[32] = {}, cnt = 0;
const u64 phi = P - 1;
u64 m = phi;
for (u64 i = 2; i * i <= m; ++i) {
if (m % i == 0) {
tmp[cnt++] = i;
while (m % i == 0) m /= i;
}
}
// Whatever is left above 1 is one more prime factor.
if (m > 1) tmp[cnt++] = m;
for (u64 res = 2; res <= phi; ++res) {
bool flag = true;
for (u32 i = 0; i != cnt && flag; ++i) flag &= pow_mod(res, phi / tmp[i]) != 1;
if (flag) return res;
}
return 0;
}
// Defaulted: leaves v uninitialised under default-initialisation.
MontgomeryModInt32() = default;
~MontgomeryModInt32() = default;
// Converts a plain value into Montgomery form: multiply by r2, reduce once.
constexpr MontgomeryModInt32(u32 v) : v(reduce(u64(v) * r2)) {}
constexpr MontgomeryModInt32(const MontgomeryModInt32 &rhs) : v(rhs.v) {}
// Additive Montgomery reduction. `+` binds tighter than `>>`, so the whole
// sum x + (low32(x) * r) * P is shifted right by 32. No final subtraction.
static constexpr u32 reduce(u64 x) { return x + (u64(u32(x) * r) * P) >> 32; }
// Converts back to a plain value. `P & -(res >= P)` evaluates to P when
// res >= P and to 0 otherwise, i.e. a branch-free conditional subtraction.
constexpr u32 get() const {
u32 res = reduce(v);
return res - (P & -(res >= P));
}
explicit constexpr operator u32() const { return get(); }
explicit constexpr operator i32() const { return i32(get()); }
constexpr MontgomeryModInt32 &operator=(const MontgomeryModInt32 &rhs) {
return v = rhs.v, *this;
}
// Negation: (P << 1) - v when v != 0, and 0 when v == 0 (the mask is then 0).
constexpr MontgomeryModInt32 operator-() const {
MontgomeryModInt32 res;
return res.v = (P << 1 & -(v != 0)) - v, res;
}
// Inverse computed as pow(-1); pow() maps the exponent -1 to P - 2.
constexpr MontgomeryModInt32 inv() const { return pow(-1); }
// v + rhs.v - 2P, then 2P is added back when the result is negative as i32.
constexpr MontgomeryModInt32 &operator+=(const MontgomeryModInt32 &rhs) {
return v += rhs.v - (P << 1), v += P << 1 & -(i32(v) < 0), *this;
}
// v - rhs.v, then 2P is added back when the result is negative as i32.
constexpr MontgomeryModInt32 &operator-=(const MontgomeryModInt32 &rhs) {
return v -= rhs.v, v += P << 1 & -(i32(v) < 0), *this;
}
// Montgomery multiplication: 64-bit product followed by one reduce().
constexpr MontgomeryModInt32 &operator*=(const MontgomeryModInt32 &rhs) {
return v = reduce(u64(v) * rhs.v), *this;
}
// Division is multiplication by rhs.inv().
constexpr MontgomeryModInt32 &operator/=(const MontgomeryModInt32 &rhs) {
return this->operator*=(rhs.inv());
}
// Binary operators: copy the left operand and apply the compound assignment.
friend MontgomeryModInt32 operator+(const MontgomeryModInt32 &lhs,
const MontgomeryModInt32 &rhs) {
return MontgomeryModInt32(lhs) += rhs;
}
friend MontgomeryModInt32 operator-(const MontgomeryModInt32 &lhs,
const MontgomeryModInt32 &rhs) {
return MontgomeryModInt32(lhs) -= rhs;
}
friend MontgomeryModInt32 operator*(const MontgomeryModInt32 &lhs,
const MontgomeryModInt32 &rhs) {
return MontgomeryModInt32(lhs) *= rhs;
}
friend MontgomeryModInt32 operator/(const MontgomeryModInt32 &lhs,
const MontgomeryModInt32 &rhs) {
return MontgomeryModInt32(lhs) /= rhs;
}
// Reads a raw u32 into v, then converts it into Montgomery form in place.
friend std::istream &operator>>(std::istream &is, MontgomeryModInt32 &rhs) {
return is >> rhs.v, rhs.v = reduce(u64(rhs.v) * r2), is;
}
// Writes the plain value returned by get().
friend std::ostream &operator<<(std::ostream &os, const MontgomeryModInt32 &rhs) {
return os << rhs.get();
}
// Binary exponentiation in Montgomery form. The exponent is reduced modulo
// P - 1 and a negative remainder is shifted up by P - 1.
constexpr MontgomeryModInt32 pow(i64 y) const {
if ((y %= P - 1) < 0) y += P - 1; // phi(P) = P - 1, assume P is a prime number
MontgomeryModInt32 res(1), x(*this);
for (; y != 0; y >>= 1, x *= x)
if (y & 1) res *= x;
return res;
}
};
这个是计算\(5\cdot 10^7!\mod 998244353\)在duck.ac上的测试结果
Naive Mod : 213689172 Time: 180649
My Montgomery : 213689172 Time: 178217
His Montgomery : 213689172 Time: 152847
这个是计算\(7\cdot 10^7!\mod 998244353\)在duck.ac上的测试结果
Naive Mod : 939830261 Time: 252908
My Montgomery : 939830261 Time: 249476
His Montgomery : 939830261 Time: 213986
还可以看Min25博客里下面的ModInt64版本
下面自己实现的\(\mod 2m\)版本,差不多也是最终版本了,跑起来和hly1204差不多
静态版本
// Integer modulo the compile-time modulus m in Montgomery form, with the
// stored value kept modulo 2m instead of fully reduced; only get() performs
// the final subtraction of m.
// NOTE(review): the wrap detection below appears to need m odd and m < 2^30;
// nothing here enforces it -- confirm.
template <uint32_t m> struct Mont2{
private :
    using u32=uint32_t;
    using i32=int32_t;
    using u64=uint64_t;
    using i64=int64_t;
    static constexpr u32 m2=m<<1;
    // Newton iteration for the inverse of m modulo 2^32.
    static constexpr u32 getinv(){
        u32 cur=m;
        int rounds=4;
        while(rounds-->0) cur*=2-cur*m;
        return cur;
    }
    // inv = -m^{-1} mod 2^32 (additive reduction), r2 = 2^64 mod m.
    static constexpr u32 inv=-getinv(),r2=-u64(m)%m;
    u32 x;
public :
    // Additive Montgomery reduction without a final conditional subtraction.
    static constexpr u32 reduce(u64 x){
        const u64 qm=u64(u32(x)*inv)*m;
        return (x+qm)>>32;
    }
    // Leaves x uninitialised.
    Mont2(){ ; }
    // Plain integer -> Montgomery form.
    constexpr Mont2(i32 x):x(reduce(u64(x)*r2)) { }
    // x + rhs - 2m; a wrapped result lands at the top of the u32 range and
    // therefore compares >= 2m, in which case 2m is added back.
    constexpr Mont2& operator += (const Mont2 &rhs) {
        x+=rhs.x-m2;
        if(x>=m2) x+=m2;
        return *this;
    }
    constexpr Mont2& operator -= (const Mont2 &rhs) {
        x-=rhs.x;
        if(x>=m2) x+=m2;
        return *this;
    }
    constexpr Mont2& operator *= (const Mont2 &rhs) {
        x=reduce(u64(x)*rhs.x);
        return *this;
    }
    constexpr friend Mont2 operator + (Mont2 x,const Mont2 &y) { x+=y; return x; }
    constexpr friend Mont2 operator - (Mont2 x,const Mont2 &y) { x-=y; return x; }
    constexpr friend Mont2 operator * (Mont2 x,const Mont2 &y) { x*=y; return x; }
    // Montgomery form -> plain integer, with the final subtraction of m.
    constexpr i32 get(){
        const u32 res=reduce(x);
        if(res>=m) return res-m;
        return res;
    }
} ;
板子各有优劣.jpg
另外这是Int_To_Montgomery加法的速度,\(7\cdot 10^7\)次加法与类型转换
Naive : : 305907824 80074
My Montgomery : 305907824 109479
My Montgomery2 : 305907824 99896
His Montgomery : 305907824 117449
动态版本
using u32=uint32_t;
using i32=int32_t;
using u64=uint64_t;
using i64=int64_t;
// Run-time Montgomery parameters, filled in by Init():
//   m   - the modulus
//   m2  - 2 * m
//   inv - negated inverse of m modulo 2^32 (additive reduction)
//   r2  - 2^64 mod m
//   P   - not referenced by this snippet
static u32 m,m2,inv,r2,P;
// Newton iteration for the inverse of the global m modulo 2^32.
u32 getinv(){
    u32 cur=m;
    int rounds=4;
    while(rounds-->0) cur*=2-cur*m;
    return cur;
}
// Integer modulo the run-time modulus m in Montgomery form; the stored value
// is kept modulo 2m and only get() performs the final subtraction of m.
// Init() must be called before any Mont is constructed.
struct Mont{
private :
    u32 x;
public :
    // Additive Montgomery reduction; m is added when the result has its sign
    // bit set.
    static u32 reduce(u64 x){
        const u64 qm=u64(u32(x)*inv)*m;
        const u32 y=(x+qm)>>32;
        if(i32(y)<0) return y+m;
        return y;
    }
    // Leaves x uninitialised.
    Mont(){ ; }
    // Plain integer -> Montgomery form.
    Mont(i32 x):x(reduce(u64(x)*r2)) { }
    // x + rhs - 2m, then 2m is added back if the result is negative as i32.
    Mont& operator += (const Mont &rhs) {
        x+=rhs.x-m2;
        if(i32(x)<0) x+=m2;
        return *this;
    }
    // x - rhs, then 2m is added back if the result is negative as i32.
    Mont& operator -= (const Mont &rhs) {
        x-=rhs.x;
        if(i32(x)<0) x+=m2;
        return *this;
    }
    Mont& operator *= (const Mont &rhs) {
        x=reduce(u64(x)*rhs.x);
        return *this;
    }
    friend Mont operator + (Mont x,const Mont &y) { x+=y; return x; }
    friend Mont operator - (Mont x,const Mont &y) { x-=y; return x; }
    friend Mont operator * (Mont x,const Mont &y) { x*=y; return x; }
    // Montgomery form -> plain integer, with the final subtraction of m.
    i32 get(){
        const u32 res=reduce(x);
        if(res>=m) return res-m;
        return res;
    }
};
// Sets the modulus and derives the constants used by Mont.
void Init(int m) {
    ::m=m;
    m2=m*2;
    inv=-getinv();
    // -u64(m) is 2^64 - m, congruent to 2^64 modulo m.
    r2=-u64(m)%m;
}
这个动态模板计算\(5\cdot 10^7!\mod 998244353\)
Naive Mod : 213689172 494061 (稍微修改了一下暴力的细节。。)
My Montgomery2 : 213689172 152849
不得不说duck.ac真的很nb