BSGS学习笔记

BSGS算法

BSGS算法使用来求解\(y\)的方程

\[x ^ y \equiv z\pmod p \]

其中\(gcd(x, p) = 1\), 我们将\(y\)写做一个\(am - b\)的形式, 其中\(a \in (1, m + 1]\), \(b \in [0, m)\)

那么这样, 原式就变成了

\[\begin{aligned} x^{am - b} &\equiv z \pmod p\\ x^{am} &\equiv z * x ^ b \pmod p \end{aligned} \]

我们枚举每一个\(b\), 将\(z * x ^ b\)存进\(hash\)或者\(map\)之类的东西

之后对于左边枚举每一个\(a\), 如果\(x ^ {am} \% p\)\(hash\)或者\(map\)中, 答案就为\(a * m - mp[x ^ {am} \% p]\)

复杂度为\(O(max(m, p / m))\), 易证得\(m\)\(\sqrt{p}\)的时候最优

Code

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <vector>
#include <cmath>
#include <map>
#define itn int
#define reaD read
using namespace std;

int p, x, y, s, m;
map<int, int> mp; 

inline int read()
{
	int x = 0, w = 1; char c = getchar();
	while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
	while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
	return x * w;
}

int fpow(int x, int k)
{
	int res = 1;
	while(k)
	{
		if(k & 1) res = 1ll * res * x % p;
		x = 1ll * x * x % p;
		k >>= 1; 
	}
	return res; 
}

int main()
{
	p = reaD(); x = read(); y = read(); m = sqrt(p) + 1; s = y; 
	for(int i = 0; i < m; i++) mp[s] = i, s = 1ll * s * x % p; 
	s = 1; int t = fpow(x, m); 
	for(int i = 1; i <= m + 1; i++)
	{
		s = 1ll * s * t % p; 
		if(mp.count(s))
		{
			printf("%d\n", i * m - mp[s]);
			return 0; 
		}
	}
	puts("no solution"); 
	return 0;
}

ExBSGS算法

好像所有有\(Ex\)的算法似乎都是在不互质的情况下诶

这里, \(ExBSGS\)是用来处理\(gcd(x, p) != 1\)的情况的, 我们可以有这样一个式子, 设\(gcd(x, p) = d\)

注意, 此时若\(d\)不整除\(z\), 方程无解, 若\(d\)整除\(z\), 则有

\[\begin{aligned} \frac{x ^ y}{d} &\equiv \frac{z}{d} \pmod{\frac{p}{d}}\\ \frac{x}{d} * x^{y - 1} &\equiv \frac{z}{d} \pmod{\frac{p}{d}}\\ \end{aligned} \]

注意到\(\frac{x}{d}\)变成了一个系数, 当\(gcd(x, p / d)\)不等于1时不断地除以\(gcd(x, p / d)\), 我们最终可以得到

\[\frac{x ^ k}{d}*x^{y - k} \equiv \frac{z}{d} \pmod{\frac{p}{d}} \]

注意到此时式子中的\(d\)不是第一次算\(gcd\)\(d\)了, 他是所有不为1的\(gcd\)的积

带上个系数跑BSGS, 最后答案加上\(k\)即可

Code

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <vector>
#include <cmath>
#include <map>
#define itn int
#define reaD read
#define LL long long
#define MOD 233333
using namespace std;

int x, y, p, d, m, cnt, sum; 
struct MAP {
	LL ha[MOD+5]; int id[MOD+5];
	void clear() {for (int i = 0; i < MOD; i++) ha[i] = id[i] = -1; }
	int count(LL x) {
		LL pos = x%MOD;
		while (true) {
			if (ha[pos] == -1) return 0;
			if (ha[pos] == x) return 1;
			++pos; if (pos >= MOD) pos -= MOD;
		}
	}
	void insert(LL x, int idex) {
		LL pos = x%MOD; 
		while (true) {
			if (ha[pos] == -1 || ha[pos] == x) { ha[pos] = x, id[pos] = idex; return; }
			++pos; if (pos >= MOD) pos -= MOD; 
		}
	}
	int query(LL x) {
		LL pos = x%MOD;
		while (true) {
			if (ha[pos] == x) return id[pos];
			++pos; if (pos >= MOD) pos -= MOD;
		}
	}
}mp;

int gcd(int n, int m) { return m ? gcd(m, n % m) : n; }

int fpow(int x, int y)
{
	int res = 1;
	while(y)
	{
		if(y & 1) res = 1ll * res * x % p;
		x = 1ll * x * x % p;
		y >>= 1; 
	}
	return res; 
}

int exbsgs(int x, int y, int p)
{
	if(y == 1) return 0;
	cnt = 0; sum = 1; mp.clear();
	while((d = gcd(x, p)) != 1)
	{
		if(y % d) return -1;
		cnt++; p /= d; y /= d; sum = 1ll * sum * (x / d) % p;
		if(y == sum) return cnt; 
	}
	m = sqrt(p) + 1;
	for(int i = 0; i < m; i++) mp.insert(y, i), y = 1ll * y * x % p;
	y = sum; x = fpow(x, m);
	for(int i = 1; i <= m + 1; i++)
	{
		y = 1ll * y * x % p;
		if(mp.count(y)) return i * m - mp.query(y) + cnt; 
	}
	return -1; 
}

int main()
{
	while(scanf("%d%d%d", &x, &p, &y) != EOF)
	{
		if(!x && !p && !y) break; 
		int ans = exbsgs(x, y, p);
		ans == -1 ? puts("No Solution") : printf("%d\n", ans); 
	}
	return 0;
}
posted @ 2019-06-14 18:04  ztlztl  阅读(182)  评论(0编辑  收藏  举报