[学习笔记]BSGS及其扩展

BSGS

概述

BSGS算法是用于求解同余方程\(a^x \equiv b \ (mod \ p)\)(其中\(a,p\)互质)的最小自然数解(或正整数解)高效算法,复杂度\(O(\sqrt p)\)

思路

首先由欧拉定理,有\(a^{\varphi (p)} \equiv 1 \ (mod \ p)\),于是\(a^{x} \equiv a^{x \mod \varphi (p)} \ (mod \ p)\)

所以\(a^x \equiv b \ (mod \ p)\)如果有解,一定存在一个满足\(x < \varphi (p) < p\)的解,我们只需考虑\(x < p\)的部分。

对于\(a^x \equiv b \ (mod \ p)\), 令\(x = i \cdot m - j\), 其中\(i = ceil(p / m), j < m\)

则原方程可化为\(a^{i \cdot m - j} \equiv b \ (mod \ p)\),或\(a^{i \cdot m} \equiv ba^{j} \ (mod \ p)\)

注意到\(i\)\(j\)的取值均不超过\(m\)个,所以可以暴力枚举\(j\),存下右边的值对应的\(j\),再暴力枚举\(i\),查询右边能否取到\(a^{im}\)。如果能,那么答案就是对应的\(i \cdot m - j\),否则就继续枚举。

如果枚举完所以\(i,j\)还没找到答案,说明无解。

容易发现上面的过程中\(i \cdot m - j\) 不会取到0,所以如果求自然数解,需特判\(b = 1\)的情况。

上面的过程中枚举\(i\)的复杂度\(O(p / m)\),枚举\(j\)的复杂度\(O(m)\)。为了使总复杂的最小,应取\(m = \sqrt{p}\),总复杂度\(O(\sqrt{p})\)

存等号右边的取值时,可以手写hash或者直接用map,用map复杂度会多一个\(log\)

实践中可以map和unordered_map都试一下,万一出题人不卡呢

代码在最下方。

扩展BSGS

概述

普通的BSGS只能解决\(a,p\)互质的情况,那么\(a,p\)不互质的情况怎么办呢?

思路

\(a^x \equiv b \ (mod \ p)\)可以表示成\(a^x + k \cdot p = b, k为整数\),假设\(gcd(a, p) = d\),上面的式子有解的条件是\(d | b\),并且可化为\(a^{x - 1} \frac{a}{d} + k \cdot \frac{p}{d} = \frac{b}{d}\),即解方程\(a^{x - 1} \frac{a}{d} \equiv \frac{b}{d} \ (mod \ \frac{p}{d})\)

于是我们可以不断的进行上面的转化,直到\(gcd(a, \frac{p}{\prod_{i}{d_i}}) = 1\),这时式子变为\(a^{x - k} \frac{a^k}{\prod_{i = 1}^{k}{d_i}} \equiv \frac{b}{\prod_{i = 1}^{k}{d_i}} \ (mod \ \frac{p}{\prod_{i}{d_i}})\),把\(\frac{a^k}{\prod_{i=1}^{k}{d_i}}\)除到右边去(乘逆元),就变成了BSGS能解的问题,原问题的答案就是BSGS的结果加上\(k\),BSGS无解则原方程无解。

如果中间某一步发现\(\frac{a^k}{\prod_{i = 1}^{k}{d_i}} \equiv \frac{b}{\prod_{i = 1}^{k}{d_i}} \ (mod \ \frac{p}{\prod_{i}{d_i}})\)了,答案就是\(k\)

代码:

//qpow(a, b, p)=a的b次幂模p
//inverse(a, p)=a在模p意义下的逆元
struct BSGS {
    std::map<LL, LL> hash;
    LL solve(LL a, LL b, LL p) {
        a %= p, b %= p;
        if (b == 1) return 0;
        hash.clear();
        LL m = sqrt(p) + 1;
        for (LL i = 0, j = b; i < m; ++i, j = j * a % p)
            hash[j] = i;
        LL A = qpow(a, m, p);
        for (LL i = 1, j = A; i <= m; ++i, j = j * A % p)
            if (hash.find(j) != hash.end()) return i * m - hash[j];
        return -1;
    }
    LL solve_ex(LL a, LL b, LL p) {
        a %= p, b %= p;
        if (b == 1 || p == 1) return 0;
        LL A = 1, k = 0;
        for (LL d = gcd(a, p); d != 1; d = gcd(a, p)) {
            if (b % d) return -1;
            p /= d, b /= d, ++k;
            A = a / d * A % p;
            if (A == b) return k;
        }
        LL res = solve(a, b * inverse(A, p) % p, p);
        return (res == -1 ? -1 : res + k);
    }
} bsgs;
posted @ 2021-07-22 10:10  Rhein_E  阅读(112)  评论(0编辑  收藏  举报