【数学】中国剩余定理
中国剩余定理
一元线性同余方程组
\( x\equiv r_1 (\mod m_1)\\ x\equiv r_2 (\mod m_2)\\ x\equiv r_3 (\mod m_3)\\ ...\\ x\equiv r_n (\mod m_n)\\ \)
其中,对于任意的 \(i\neq j\) , \(gcd(m_i,m_j)=1\) 。
算法流程:
计算所有模的乘积 \(M=\prod\limits_{i=1}^n m_i\)
对于第 \(i\) 个方程,计算新的模 \(M_i=\frac{M}{m_i}\)
计算新的模 \(M_i\) 在模 \(m_i\) 意义下的逆元 \(inv(M_i)\)
方程的唯一解为 \(x\equiv \sum\limits_{i=1}^n r_i*M_i*inv(M_i) (\mod M)\)
中国剩余定理并不会无解。所有模数的乘积不能超过long long的范围,后面的乘法溢出可以用qmul避免。
验证链接:https://www.luogu.com.cn/problem/P1495
namespace CRT {
ll qmul(ll a, ll b, ll mod) {
ll res = 0;
while(b) {
if(b & 1)
res = (res + a) % mod;
a = (a + a) % mod, b >>= 1;
}
return res;
}
ll exgcd(ll a, ll b, ll &x, ll &y) {
if(b == 0) {
x = 1, y = 0;
return a;
}
ll d = exgcd(b, a % b, x, y), t;
t = x, x = y, y = t - a / b * y;
return d;
}
ll inv(ll a, ll m) {
ll x, y, d = exgcd(a, m, x, y);
if(d != 1)
// no solution
return -1;
ll x0 = (x % m + m) % m;
// solution is: x = x0
return x0;
}
// x = r_i mod m_i
pll crt(ll *r, ll *m, int n) {
ll M = 1, R = 0;
for(int i = 1; i <= n; ++i)
M *= m[i];
for(int i = 1; i <= n; ++i) {
ll Mi = M / m[i], invMi = inv(Mi, m[i]);
R = (R + qmul(r[i], qmul(Mi, invMi, M), M)) % M;
}
// solution is: x = R mod M
return pll(R, M);
}
}
扩展中国剩余定理
扩展中国剩余定理不再要求模数两两互质,实际上是依次合并若干个同余方程的过程,发现矛盾则报告无解。返回合并成功后的方程 \(x = R \mod M\) ,根据这个公式可以知道通解。在通过最小非负整数解来构造某个特解的过程中,配合倍增和同余方程来确定范围。
验证链接:https://www.luogu.com.cn/problem/P4777
namespace exCRT {
ll qmul(ll a, ll b, ll mod) {
ll res = 0;
while(b) {
if(b & 1)
res = (res + a) % mod;
a = (a + a) % mod, b >>= 1;
}
return res;
}
ll exgcd(ll a, ll b, ll &x, ll &y) {
if(b == 0) {
x = 1, y = 0;
return a;
}
ll d = exgcd(b, a % b, x, y), t;
t = x, x = y, y = t - a / b * y;
return d;
}
pll merge(ll r1, ll m1, ll r2, ll m2) {
ll x, y, t = ((r2 - r1) % m2 + m2) % m2, d = exgcd(m1, m2, x, y);
if(t % d != 0)
// no solution
return pll(-1, -1);
ll R = (r1 + qmul(x, t / d, m2 / d) * m1), M = m1 * (m2 / d), R = (R % M + M) % M;
// solution is: x = R mod M
return pll(R, M);
}
// x = r_i mod m_i
pll excrt(ll *r, ll *m, int n) {
ll M = m[1], R = (r[1] % M + M) % M;
for(int i = 2; i <= n; ++i) {
pll res = merge(R, M, r[i], m[i]);
if(res.first == -1 && res.second == -1)
// no solution
return res;
R = res.first, M = res.second;
}
// solution is: x = R mod M
return pll(R, M);
}
}