【数学】扩展中国剩余定理
问题描述
给定 \(n\) 个正整数 \(a_i\) 和 \(n\) 个非负整数 \(b_i\),求解关于 \(x\) 的线性同余方程组:
\[\begin{cases}
x\equiv b_1\pmod {a_1}\\
x\equiv b_2\pmod {a_2}\\
\cdots\\
x\equiv b_n\pmod {a_n}\\
\end{cases}
\]
其中 \(\exist i,j\in[1,n],i\ne j,\gcd(a_i,a_j)>1\)。
定理内容
前置芝士:
- \(\rm exgcd\)
- 不需要了解 \(\rm CRT\)。
首先,对于普通 \(\rm CRT\),答案就是
\[x=(\sum\limits_{i=1}^n b_im_iM_i)\bmod m
\]
其中 \(m=\prod\limits_{i=1}^{n}a_i,m_i=\dfrac{m}{a_i},M_i=m_i^{-1}\bmod a_i\)。
但是这次 \(a_i\) 不一定互质,所以 \(m_i\) 也不一定和 \(a_i\) 互质,那么 \(M_i\) 就有可能不存在。
所以整个 \(\rm CRT\) 就废掉了,而且这个漏洞是没法补的,只能重头建起。
我们要用到扩展中国剩余定理(\(\rm Extended\ Chinese\ Remainder\ Theorem,CRT\))。
思路是将同余方程两两合并。
假设现在处理到第 \(i\) 个同余方程:
\[\begin{cases}
x\equiv B\pmod A&(1)\\
x\equiv b\pmod a&(2)
\end{cases}
\]
其中 \((1)\) 是前 \((i-1)\) 个方程合并后的结果,\((2)\) 是第 \(i\) 个方程,我们现在要把它们合并成一个。
将上面两个写成不定方程的形式:
\[\begin{cases}
x+Ap=B\\
x+aq=b
\end{cases}
\]
所以
\[Ap-aq=B-b\\
Ap+a(-q)=B-b
\]
就可以用 \(\rm exgcd\) 算了。
解出 \(p,q\) 后将 \(p\) 回代:\(x=B-Ap\)。这个 \(x\) 就是同时满足 \((1)(2)\) 的。
注意此时模数变成了 \(\operatorname{lcm}(A,a)\),即前 \(i\) 个方程合并后就是
\[x\equiv B-Ap\pmod{\operatorname{lcm}(A,a)}
\]
一开始令 \(A=a_1,B=b_1\),从 \(2\) 开始合并即可。
时间复杂度为 \(O(n\log a)\)。
参考代码
//18 = 9 + 9 = 18.
#include <iostream>
#include <cstdio>
#define Debug(x) cout << #x << "=" << x << endl
typedef long long ll;
using namespace std;
ll x, y;
ll exgcd(ll a, ll b)
{
if (!b)
{
x = 1, y = 0;
return a;
}
ll gcd = exgcd(b, a % b);
ll tmp = x;
x = y;
y = tmp - a / b * y;
return gcd;
}
ll qmul(ll a, ll b, ll p)
{
if (b < 0)
{
a = -a;
b = -b;
}
ll base = a, ans = 0;
while (b)
{
if (b & 1)
{
ans = (ans + base) % p;
}
base = (base + base) % p;
b >>= 1;
}
return (ans + p) % p;
}
ll A, B;
int merge(ll a, ll b)
{
ll Gcd = exgcd(A, a), c = B - b;
if (c % Gcd)
{
return -1;
}
c /= Gcd;
x = (qmul(x, c, a) + a) % a;
ll Lcm = A / Gcd * a;
B = (B - qmul(A, x, Lcm) + Lcm) % Lcm;
A = Lcm;
return 1;
}
const int MAXN = 1e5 + 5;
ll a[MAXN], b[MAXN];
ll exCRT(int n)
{
A = a[1], B = b[1];
for (int i = 2; i <= n; i++)
{
if (merge(a[i], b[i]) == -1)
{
return -1;
}
}
return B;
}
int main()
{
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
scanf("%lld%lld", a + i, b + i);
}
printf("%lld\n", exCRT(n));
return 0;
}