128MTT 学习笔记

标题是我乱起的名字

在做某题时受到了启发,想出了一种之前没听说过的 MTT,在某谷上一问发现有人和我想的一样,立马去学了。

这种方法,我叫它 128MTT,它用到了科技 __int128。主要思想就是找一个 \(10^{27}\) 以上的大 NTT 模数,全程使用 __int128 做 NTT。这种方法有一个好处就是支持的变换长度远大于传统 NTT/FFT(大概是它们的 \(2^{50}\sim 2^{100}\) 倍)。

关于 __int128 的模乘:我使用的是 Montgomery 模乘(跑得飞快),实测拆系数乘常数极大,慢于三模 NTT。

大常数的拆系数乘(最快的一版,为效率选用原根为 \(11\) 的小模数 \(39 \times 2^{70}+1\)

// MOD = 39*2**70+1
// u128, UP 等定义见最下方
u128 mul(u128 x, u128 y){
	u128 ans = 0;
	for(int i:{0, 1}){
		ans = ((ans<<48) + y*((x>>48)&0x0000FfffFfffFfffUll)) % MOD;
		x <<= 48;
	}
	return ans;
}

这样,我们就可以做一次 NTT 了,吊打各种 \(3\)\(4\) 次的 FFT。(bushi)

这样,我们就可以类 NTT 来做 MTT 了。原理与 NTT 完全相同。

关于 NTT 模数:

for i in range(1000):
    if miller_rabin(2**76*i+1): print(i, 2**76*i+1)

其中 Miller-Rabin 可以在 OI-Wiki 上找到。

推荐 \(25 \times 2^{78}+1\)\(7 \times 2^{120}+1\),它们的最小原根都是 \(3\),且好记。本文用的是 \(25 \times 2^{78}+1\).

例题:洛谷 P4245

#include <iostream>
#include <fstream>
#include <cassert>
#include <algorithm>
#include <vector>
#define UP(i,s,e) for(auto i=s; i!=e; ++i)
using std::cin;
using std::cout;
namespace BITS{ // }{{{
using i64 = long long;
using u64 = unsigned long long;
using i128 = __int128_t;
using u128 = __uint128_t;
struct u256{
	u128 hi, lo;
	constexpr u256(u128 hi, u128 lo):hi(hi), lo(lo){}
	constexpr u256 &operator=(u256 x){ hi=x.hi, lo=x.lo; return *this; }
	constexpr u256 operator+=(u256 x){ 
		if(((lo>>1)+(x.lo>>1)+(lo&1&x.lo)) >> 63) hi++;
		hi += x.hi; lo += x.lo;
		return *this;
	}
};
constexpr u256 mul128(u128 x, u128 y){
	u64 xl(x), xh(x>>64), yl(y), yh(y>>64);
	u128 mid(u128(xl)*yh), anslo(u128(xl)*yl);
	u64 mh(mid>>64), ml(mid);
	mid = u128(yl)*xh + ml + (anslo >> 64);
	return u256(u128(xh)*yh + mh + (mid>>64), 
			(mid<<64) | u64(anslo));
}
constexpr u128 str_to_u128(char const *s){
	u128 x(0);
	while(*s){ x = x*10+(*s)-'0'; s++; }
	return x;
}
char* u128_to_string(u128 x){ // u128 < 10**39
	char buf[40]={0}, *p = buf+38;
	do { 
		*p = x%10 + '0';
		x/=10;
		p--;
	} while(x>0);
	return p+1;
}
} // {}}}
BITS::u64 gethi(BITS::u128 x){ return x>>64; }
BITS::u64 getlo(BITS::u128 x){ return x; }
template<typename T>
constexpr T exgcd(T a, T b, T &x, T &y){
	if(b == 0) return x=1, y=0, void();
	exgcd(b, a%b, y, x); y-=a/b*x;
}
template<typename T>
constexpr T qpow(T x, BITS::u128 tim){
	T a(1);
	while(tim){
		if(tim&1) a*=x;
		x*=x;
		tim>>=1;
	}
	return a;
}
namespace Montgomery{ // }{{{
using namespace BITS;
constexpr u128 MOD = str_to_u128("7555786372591432341913601"); // MOD must be < 2**63
constexpr int MOD_G = 3;
constexpr u128 NIMOD = -qpow(MOD, (u128(1)<<127)-1);
constexpr int lgR = 128;
constexpr u128 reduce(u256 x){
	x += mul128(x.lo*NIMOD, MOD); // mod 2**lgR
	assert(x.lo == 0);
	return x.hi >= MOD ? x.hi-MOD : x.hi;
}
constexpr u128 R1 = reduce(mul128((-MOD)%MOD, (-MOD)%MOD));
constexpr u128 R2_(){
	u256 ans = mul128(R1, R1); u128 anslo(ans.lo%MOD);
	while(ans.hi > 0){
		ans.hi %= MOD;
		ans = mul128(ans.hi, R1);
		anslo += ans.lo%MOD;
		anslo %= MOD;
	}
	return anslo;
}
constexpr u128 R2 = R2_();
constexpr u128 R3 = reduce(mul128(R2, R2));
constexpr u128 mmul(u128 x, u128 y){ return reduce(mul128(x,y)); }
class m128{ // Montgomery unsigned 128 bit modulo MOD
	private:
	u128 val;
	public:
	m128(){}
	m128 &operator=(m128 x){ val = x.val; return *this; }
	m128 &operator=(u128 x){ val = mmul(x, R2); return *this; }
	m128(u128 x){ *this = x; }
	m128 &operator+=(m128 x){ val += x.val; val = val>=MOD ? val-MOD :val; return *this; }
	m128 operator+(m128 x){ x+=*this; return x; }
	m128 &operator-=(m128 x){ if(val < x.val) val+=MOD; val -= x.val; return *this; }
	m128 operator-(m128 x){ m128 t=*this; t-=x; return t; }
	m128 &operator*=(m128 x){ val = mmul(val, x.val); return *this; }
	m128 operator*(m128 x){ x*=*this; return x; }
	u128 get(){ return mmul(val, 1); }
};
} // {}}}
namespace Poly{ // }{{{
using CC = Montgomery::m128;
using Montgomery::MOD_G;
using Montgomery::MOD;
void change(CC *y, int len){
	static std::vector<int> rev; static int llen = 0;
	if(llen != len){
		llen = len;
		rev.reserve(len);
		rev[0] = 0;
		UP(i, 1, len){
			rev[i] = rev[i>>1] >> 1;
			if(i&1) rev[i] |= len>>1;
		}
	}
	UP(i, 0, len) if(i<rev[i]) std::swap(y[i], y[rev[i]]);
}
void ntt(CC *y, int len, bool idft){
	change(y, len);
	for(int h=2; h<=len; h<<=1){
		CC wn = qpow(CC(MOD_G), (MOD-1)/h);
		for(int j=0; j<len; j+=h){
			CC w(1);
			UP(k, j, j+h/2){
				CC u = y[k], v = w * y[k+h/2];
				y[k] = u+v, y[k+h/2] = u-v;
				w = w * wn;
			}
		}
	}
	if(idft){
		std::reverse(y+1, y+len);
		CC invlen = qpow(CC(len), MOD-2);
		UP(i, 0, len){
			y[i] *= invlen;
		}
	}
}
void dot(CC *y, CC *z, int len){ UP(i, 0, len) y[i] *= z[i]; }
void polymul(CC *y, CC *z, int len){
	ntt(y, len, false);
	ntt(z, len, false);
	dot(y, z, len);
	ntt(y, len, true);
}
} // {}}}
namespace m{ // }{{{
int in, im, ip;
Poly::CC ia[1<<18], ib[1<<18];
void work(){
	cin >> in >> im >> ip;
	UP(i, 0, in+1){
		int x;
		cin >> x;
		ia[i] = x;
	}
	UP(i, 0, im+1){
		int x;
		cin >> x;
		ib[i] = x;
	}
	int len = 1;
	while(len < im+in+1) len<<=1;
	Poly::polymul(ia, ib, len);
	UP(i, 0, in+im+1){
		cout << BITS::u64(ia[i].get()%ip) << ' ';
	}
}
} // {}}}
int main(){
#if ONLINE_JUDGE
	std::ios::sync_with_stdio(0); std::cin.tie(0);
#endif
   	m::work(); return 0; 
}

评测记录

跑的比这篇慢多了,果然人傻常大 =(

upd: 是没预处理单位根的问题

posted @ 2023-07-24 16:28  383494  阅读(121)  评论(0编辑  收藏  举报