Luogu 4777 【模板】扩展中国剩余定理(EXCRT)

复习模板。

两两合并同余方程

$x\equiv C_{1} \ (Mod\ P_{1})$

$x\equiv C_{2} \ (Mod\ P_{2})$

把它写成不定方程的形式:

$x = C_{1} + P_{1} * y_{1}$

$x = C_{2} + P_{2} * y_{2}$

发现上下两式都等于$x$

所以$C_{1} + P_{1} * y_{1} = C_{2} + P_{2} * y_{2}$

稍微移项一下,有$P_{1} * y_{1} + P_{2} * (-y_{2}) = C_{2} - C_{1}$。

发现这个式子很眼熟,其实就是一个不定方程,那么根据裴蜀定理,要使此方程有解需要满足$gcd(P_{1}, P_{2}) | (C_{2} - C_{1})$,否则这一堆同余方程就无解了。

我们有$exgcd$算法可以解这个$y_{1}$,解出来之后把它回代到上式里面去,就得到了合并之后的同余方程:$x\equiv C_{1} + P_{1} * y_{1} \ (Mod\ lcm(P_{1}, P_{2}))$。

根据【NOI2018】屠龙勇士的经验,当$P == 1$的时候,这个同余方程其实是没什么意义的,但是把它代进去算就会挂掉,所以需要特判掉。

发现乘法会溢出,需要使用龟速乘,按照我自己的sb写法,要注意在龟速乘的时候保证$y \geq 0$。

时间复杂度$O(nlog^{2}n)$,然而欧几里得算法的$log$基本上都跑不满。

Code:

#include <cstdio>
#include <cstring>
using namespace std;
typedef long long ll;

const int N = 1e5 + 5;

int n;
ll rest[N], mod[N];

template <typename T>
inline void read(T &X) {
    X = 0; char ch = 0; T op = 1;
    for(; ch > '9' || ch < '0'; ch = getchar())
        if(ch == '-') op = -1;
    for(; ch >= '0' && ch <= '9'; ch = getchar())
        X = (X << 3) + (X << 1) + ch - 48;
    X *= op;
}

inline ll mul(ll x, ll y, ll P) {
    ll res = 0;
    for(; y > 0; y >>= 1) {
        if(y & 1) res = (res + x) % P;
        x = (x + x) % P;
    }
    return res;
}

ll exgcd(ll a, ll b, ll &x, ll &y) {
    if(!b) {
        x = 1, y = 0;
        return a;
    }
    
    ll res = exgcd(b, a % b, x, y), tmp = x;
    x = y, y = tmp - (a / b) * y;
    return res;
}

inline ll exCrt() {
    ll P, c, x, y, d, t; int pos = 0;
    for(int i = 1; i <= n; i++) 
        if(mod[i] != 1) {
            pos = i, P = mod[i], c = rest[i];
            break;
        }
    for(int i = pos + 1; i <= n; i++) {
        if(mod[i] == 1) continue;
        d = exgcd(P, mod[i], x, y);
        ll r = (((rest[i] - c)) % mod[i] + mod[i]) % mod[i];
        t = mul(x, r / d, mod[i] / d);
//        t = (rest[i] - c) / d * x;
//        t = (t % (mod[i] / d) + (mod[i] / d)) % (mod[i] / d);
//        c = (c + mul(P, t, P / d * mod[i])) % (P / d * mod[i]);
        c = c + P * t;
        P = P / d * mod[i];
        c = (c % P + P) % P;
    }
    return (c % P + P) % P;
}

int main() {    

    read(n);
    for(int i = 1; i <= n; i++) 
        read(mod[i]), read(rest[i]);
    
    printf("%lld\n", exCrt());
    return 0;
}
View Code

 

posted @ 2018-09-11 19:30  CzxingcHen  阅读(207)  评论(0编辑  收藏  举报