【数学】中国剩余定理
给定 \(n\) 个正整数 \(a_i\) 和 \(n\) 个非负整数 \(b_i\),求解关于 \(x\) 的线性同余方程组:
\[\begin{cases}
x\equiv b_1\pmod {a_1}\\
x\equiv b_2\pmod {a_2}\\
\cdots\\
x\equiv b_n\pmod {a_n}\\
\end{cases}
\]
其中 \(\forall i,j\in[1,n],i\ne j\) 满足 \(a_i,a_j\) 互质。
首先令 \(m=\prod\limits_{i=1}^{n}a_i\)。
如果有一个解 \(x_1\),那么通解就是 \(x=x_1+tm\)(\(t\) 为整数)。
所以我们只需要求出最小解即可,那么这个最小解一定 \(< m\)。
中国剩余定理(\(\rm Chinese\ Remainder\ Theorem, CRT\))给出了构造解的方法:
\(x=(\sum\limits_{i=1}^n b_im_iM_i)\bmod m\)
其中 \(m_i=\dfrac{m}{a_i}\),\(M_i\) 为 \(m_i\) 在模 \(a_i\) 意义下的逆元。
则 \(\forall i,j\in[1,n],i\ne j\):
\[\begin{aligned}
b_jm_jM_j
&\equiv b_j\dfrac{m}{a_j}M_j\\
&\equiv b_j\dfrac{a_1 a_2\cdots a_i\cdots a_n}{a_j}M_j\\
&\equiv b_j\cdot 0\cdot M_j\\
&\equiv 0
\pmod {a_i}
\end{aligned}
\]
而 \(\forall i\in[1,n]\):
\[\begin{aligned}
b_im_iM_i
&\equiv b_i(m_iM_i)\\
&\equiv b_I(m_im_i^{-1})\\
&\equiv b_i\cdot 1\\
&\equiv b_i
\pmod {a_i}
\end{aligned}
\]
所以 \(\forall i\in[1,n]\):
\[\begin{aligned}
\sum\limits_{i=1}^n b_im_iM_i
&\equiv \sum\limits_{j=1,j\ne i}^n b_jm_jM_j+b_im_iM_i\\
&\equiv 0+b_i\\
&\equiv b_i
\pmod {a_i}
\end{aligned}
\]
这就是一组满足条件的解了。
接下来证明在模 \(m\) 的意义下只有一个解。
假设有 \(x\equiv y\pmod m\),那么 \(\forall i\in[1,n]\),有
\[\begin{aligned}x &\equiv x\bmod m\\ &\equiv y \pmod {a_i} \end{aligned} \]所以 \(x\) 和 \(y\) 是同一个解。
时间复杂度为 \(\mathcal{O}(n\log n)\)。
\(\text{Code}\)
//18 = 9 + 9 = 18.
#include <iostream>
#include <cstdio>
typedef long long ll;
using namespace std;
const int MAXN = 15;
ll x, y;
void exgcd(ll a, ll b)
{
if (!b)
{
x = 1, y = 0;
return;
}
exgcd(b, a % b);
ll tmp = x;
x = y;
y = tmp - a / b * y;
}
ll inv(ll a, ll p)
{
exgcd(a, p);
x = (x % p + p) % p;
return x;
}
ll qmul(ll a, ll b, ll p)
{
ll base = a, ans = 0;
while (b)
{
if (b & 1)
{
ans = (ans + base) % p;
}
base = (base + base) % p;
b >>= 1;
}
return ans;
}
int a[MAXN], b[MAXN];
ll CRT(int n)
{
ll m = 1;
for (int i = 1; i <= n; i++)
{
m *= a[i];
}
ll ans = 0;
for (int i = 1; i <= n; i++)
{
ll mi = m / a[i];
ll Mi = inv(mi, a[i]);
ans = (ans + qmul(b[i], qmul(mi, Mi, m), m)) % m;
}
return ans;
}
int main()
{
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i++)
{
scanf("%d%d", a + i, b + i);
}
printf("%lld\n", CRT(n));
return 0;
}
双倍经验:P3868 [TJOI2009] 猜数字