BSGS算法
简介
BSGS(baby-step,giant-step)算法可在 \(O(\sqrt p)\) 的时间内求解离散对数问题,所谓离散对数,就是形如下面的同余方程的解:
\[a^x\equiv b\pmod p
\]
其中 \(a\perp p\)
原理
因为 \(a\perp p\) ,由欧拉定理可得 \(a^{\varphi(p)}\equiv 1 \pmod p\) ,而 \(a^0\equiv 1\pmod p\) ,所以 \(0\) 到 \(\varphi(p)\) 包含了模 \(p\) 的所有情况,是一个周期,而更高次幂只是重复周期内的变化
由带余除法可知, \(\forall m,\exist i_1,k_1\text{ s.t. } x=i_1\times m+k_1\) ,那么设 \(i=i_1+1,k=m-k_1\) ,则 \(x=i\times m-k\) 其中 \(1\leq k\leq m\) 则原式可变为:
\[a^{im}\equiv a^k b\pmod p
\]
-
然后我们对右边的 \(k\) 进行枚举,计算出右边的值作为 \(key\) ,\(k\) 作为 \(value\) 放入哈希表
-
再枚举 \(i\) ,计算出左边的值,并以这个值为 \(key\) 在哈希表中查找,如果找到了对应的 \(k\) ,则方程有解 \(x=i\times m-k\) ,若枚举结束了仍没找到,则无解
时间复杂度为 \(O(max(m,\varphi(p)/m))\) ,取 \(m=\lceil\sqrt p\rceil\) 时得到最优复杂度
模板
下面给出的两份代码分别用了自己写的哈希表和STL的map来实现算法,map的查询和插入复杂度都是 \(O(\log n)\) ,所以效率不如自己写的哈希表,好处就是代码量较小
非STL( \(53ms\) )
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int SZ = 100007;
struct hash_table
{
int head[SZ], next[SZ], val[SZ];
ll key[SZ];
int cnt;
void init()
{
cnt = 0;
memset(head, -1, sizeof(head));
}
int hash(ll k)
{
return (k % SZ + SZ) % SZ;
}
int count(ll k)
{
int hu = hash(k);
for(int i = head[hu]; i != -1; i = next[i])
if(key[i] == k)
return i;
return -1;
}
int& operator[](ll k)
{
int idx = count(k);
if(idx != -1) {
return val[idx];
} else {
int hu = hash(k);
key[cnt] = k;
next[cnt] = head[hu];
head[hu] = cnt;
return val[cnt++];
}
}
}tb;
ll qpow(ll a, ll b, ll p)
{
if(b == 0)
return 1;
ll res = qpow(a, b / 2, p);
if(b % 2)
return res * res % p * a % p;
return res * res % p;
}
int bsgs(ll a, ll b, ll p)
{
a %= p, b %= p;
if(a == 0)
return b == 0 ? 1 : -1;
int t = ceil(sqrt(p));
tb.init();
ll pw = b * a % p;
for(int i = 1; i <= t; i++) {
tb[pw] = i;
pw = pw * a % p;
}
pw = qpow(a, t, p);
ll pw1 = pw;
for(int i = 1; i <= t; i++) {
if(tb.count(pw1) != -1)
return i * t - tb[pw1];
pw1 = pw1 * pw % p;
}
return -1;
}
int main()
{
int p, a, b;
cin >> p >> a >> b;
int ans = bsgs(a, b, p);
if(ans == -1)
cout << "no solution" << endl;
else
cout << ans << endl;
return 0;
}
STL( \(185ms\) )
#include<bits/stdc++.h>
#define ll long long
using namespace std;
ll qpow(ll a, ll b, ll p)
{
ll res = 1;
for(; b; b >>= 1) {
if(b & 1)
res = res * a % p;
a = a * a % p;
}
return res;
}
ll exgcd(ll a, ll b, ll &x, ll &y)
{
if(b == 0) {
x = 1;
y = 0;
return a;
}
ll d = exgcd(b, a % b, y, x);
y -= a / b * x;
return d;
}
ll bsgs(ll a, ll b, ll p)
{
a %= p, b %= p;
if(a == 0)
return b == 0 ? 1 : -1;
map<int, int> hash;
int t = ceil(sqrt(p));
ll pw = b * a % p;
for(int i = 1; i <= t; i++) {
hash[pw] = i;
pw = pw * a % p;
}
pw = qpow(a, t, p);
ll pw1 = pw;
for(int i = 1; i <= t; i++) {
if(hash.count(pw1))
return i * t - hash[pw1];
pw1 = pw1 * pw % p;
}
return -1;
}
int main()
{
int p, a, b;
cin >> p >> a >> b;
int ans = bsgs(a, b, p);
if(ans == -1)
cout << "no solution" << endl;
else
cout << ans << endl;
return 0;
}