【数学】扩展中国剩余定理

问题描述

给定 \(n\) 个正整数 \(a_i\)\(n\) 个非负整数 \(b_i\),求解关于 \(x\) 的线性同余方程组:

\[\begin{cases} x\equiv b_1\pmod {a_1}\\ x\equiv b_2\pmod {a_2}\\ \cdots\\ x\equiv b_n\pmod {a_n}\\ \end{cases} \]

其中 \(\exist i,j\in[1,n],i\ne j,\gcd(a_i,a_j)>1\)

定理内容

前置芝士:

  • \(\rm exgcd\)
  • 不需要了解 \(\rm CRT\)

首先,对于普通 \(\rm CRT\),答案就是

\[x=(\sum\limits_{i=1}^n b_im_iM_i)\bmod m \]

其中 \(m=\prod\limits_{i=1}^{n}a_i,m_i=\dfrac{m}{a_i},M_i=m_i^{-1}\bmod a_i\)

但是这次 \(a_i\) 不一定互质,所以 \(m_i\) 也不一定和 \(a_i\) 互质,那么 \(M_i\) 就有可能不存在。

所以整个 \(\rm CRT\) 就废掉了,而且这个漏洞是没法补的,只能重头建起。

我们要用到扩展中国剩余定理(\(\rm Extended\ Chinese\ Remainder\ Theorem,CRT\))。

思路是将同余方程两两合并。

假设现在处理到第 \(i\) 个同余方程:

\[\begin{cases} x\equiv B\pmod A&(1)\\ x\equiv b\pmod a&(2) \end{cases} \]

其中 \((1)\) 是前 \((i-1)\) 个方程合并后的结果,\((2)\) 是第 \(i\) 个方程,我们现在要把它们合并成一个。

将上面两个写成不定方程的形式:

\[\begin{cases} x+Ap=B\\ x+aq=b \end{cases} \]

所以

\[Ap-aq=B-b\\ Ap+a(-q)=B-b \]

就可以用 \(\rm exgcd\) 算了。

解出 \(p,q\) 后将 \(p\) 回代:\(x=B-Ap\)。这个 \(x\) 就是同时满足 \((1)(2)\) 的。

注意此时模数变成了 \(\operatorname{lcm}(A,a)\),即前 \(i\) 个方程合并后就是

\[x\equiv B-Ap\pmod{\operatorname{lcm}(A,a)} \]

一开始令 \(A=a_1,B=b_1\),从 \(2\) 开始合并即可。

时间复杂度为 \(O(n\log a)\)

参考代码

//18 = 9 + 9 = 18.
#include <iostream>
#include <cstdio>
#define Debug(x) cout << #x << "=" << x << endl
typedef long long ll;
using namespace std;

ll x, y;

ll exgcd(ll a, ll b)
{
	if (!b)
	{
		x = 1, y = 0;
		return a;
	}
	ll gcd = exgcd(b, a % b);
	ll tmp = x;
	x = y;
	y = tmp - a / b * y;
	return gcd;
}

ll qmul(ll a, ll b, ll p)
{
	if (b < 0)
	{
		a = -a;
		b = -b;
	}
	ll base = a, ans = 0;
	while (b)
	{
		if (b & 1)
		{
			ans = (ans + base) % p;
		}
		base = (base + base) % p;
		b >>= 1;
	}
	return (ans + p) % p;
}

ll A, B;

int merge(ll a, ll b)
{
	ll Gcd = exgcd(A, a), c = B - b;
	if (c % Gcd)
	{
		return -1;
	}
	c /= Gcd;
	x = (qmul(x, c, a) + a) % a;
	ll Lcm = A / Gcd * a;
	B = (B - qmul(A, x, Lcm) + Lcm) % Lcm;
	A = Lcm;
	return 1;
}

const int MAXN = 1e5 + 5;

ll a[MAXN], b[MAXN];

ll exCRT(int n)
{
	A = a[1], B = b[1];
	for (int i = 2; i <= n; i++)
	{
		if (merge(a[i], b[i]) == -1)
		{
			return -1;
		}
	}
	return B;
}

int main()
{
	int n;
	scanf("%d", &n);
	for (int i = 1; i <= n; i++)
	{
		scanf("%lld%lld", a + i, b + i);
	}
	printf("%lld\n", exCRT(n));
	return 0;
}
posted @ 2021-12-03 13:58  mango09  阅读(26)  评论(0编辑  收藏  举报
-->