离散对数 (BSGS 分块 哈希表) [2022.5.1]
离散对数
题目描述
对于 $A^x \equiv 𝐵\ \ (mod\ P)$,找到 x 的最小正整数解,无解则输出一行 "No solution"。
输入格式
第一行一个整数 n,表示有 n 组数据。
接下来 n 行,每行三个整数 A, B, P。输出格式
输出 n 行,每行一个整数表示答案,或"No solution"表示找不到正整数解。
样例输入
3
2 1 5
2 3 5
3 3 5样例输出
0
3
1数据范围与约定
对于前 20%的数据,保证 $P<=1000$;
对于前 70%的数据,保证 $n<=5$;
对于 100%的数据,保证 $1<=n<=200,1<=A, B<P,2<=P<2^{31}$ 且为质数
解题思路
直接上暴力的做法
考虑到\(A ^ x\ \equiv\ A ^ {x\ mod\ \phi(p)}\ \ (mod\ p)\)
那么可以枚举x的值,之后的BSGS大致也可以理解为对这种枚举的优化
然后就是正解,分块+哈希表
将p进行分块,设分块后的长度为len
这样就可以把\(A ^ x\)表示为\(A ^ {i * len - j}\)的形式
然后移项可以得到\(A ^ {i\ *\ len} \equiv B\ *\ A ^ {j}\ \ (mod\ p)\)
然后把等式右侧的len个值存进哈希表里
等式左侧枚举i,复杂度为\(\mathrm{O}(\sqrt{p})\)
每次枚举得到的值在右端的哈希表中查找是否存在即可
代码里还有用 unodered_map 写的版本 (时间大概是手写哈希桶的二倍)
使用了内置函数 mp.find() 或者 mp.count()
以及 mp.clear()
code:
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int mod = 100005 - 8;
const int base = 233;
int a, b, p, n;
struct node {
int val, cnt, id;
//val 记录桶这个位置上存的值
//cnt 记录这个值被插入了几次
//id 记录本题需要用到的j
}hs[100005];
//unordered_map<int, int> mp;
int find(int k) {
int res = k % mod;//简单的哈希映射
while (hs[res].val != k && hs[res].cnt != 0) {
res += base;
if (res >= mod) res -= mod;
}
return hs[res].val == k ? hs[res].id : -1;
}
void insert(int k, int key) {//哈希表插入,可以和插入写在一起
int res = k % mod;
while (hs[res].val != k && hs[res].cnt != 0) {
res += base;//以base来解决冲突问题
if (res >= mod) res -= mod;
}
hs[res].cnt++;
hs[res].val = k;
hs[res].id = key;
}
int fast_pow(int base, int x) {
int res = 1;
while (x) {
if (x & 1) res = 1ll * res * base % p;
base = 1ll * base * base % p;
x >>= 1;
}
return res;
}
int main() {
// freopen("log.in", "r", stdin);
// freopen("log.out", "w", stdout);
scanf("%d", &n);
while (n--) {
// mp.clear();
memset(hs, 0, sizeof(hs));
scanf("%d %d %d", &a, &b, &p);
int len = sqrt(p) + 1, tag = 1;
for (int i = 0; i <= len; i++) {
if (i) tag = 1ll * tag * a % p;
// mp[1ll * tag * b % p] = i;
insert(1ll * tag * b % p, i);
}
bool solv = 0;
int fl = fast_pow(a, len), ans = INT_MAX;
tag = 1;
for (int i = 1; (i - 1) * len <= p - 2; i++) {
tag = 1ll * tag * fl % p;
int k = find(tag);
if (~k) {
solv = 1;
ans = i * len - k;
break;
}
// if (mp.find(tag) != mp.end()) {
// solv = 1;
// ans = min(i * len - mp[tag], ans);
// break;
// }
}
if (solv) printf("%d\n", ans);
else printf("No solution\n");
}
}