Loading

(扩展)BSGS

BSGS

概念

BSGS(Baby-step Giant-step),即大步小步法,用于求解关于 \(x\) 的形如 $a^x \equiv n(\bmod p) $ 的高次不定方程的最小非负整数解,其中 \(a, b, p\) 为已经给出的常数且 \(a, p\) 互质

思想

\(x = A\lceil \sqrt{p} \rceil - B\),其中 \(0 \leq A,B \leq \lceil\sqrt{p}\rceil\)

\(a^{A \lceil \sqrt{p} \rceil - B} \equiv n(\bmod p)\)

因为 \(a, p\) 互质,所以在模 \(p\) 意义下进行含 \(a\) 的乘除运算没有影响

\(\therefore a^{A \lceil \sqrt{p} \rceil} \equiv na^B(\bmod p)\)

考虑用哈希表预处理出所有 \(na^B\) 的可能取值以及其对应的 \(B\)

再枚举 \(A\) 的所有取值,假设哈希表中存在 \(a^{i \lceil \sqrt{p} \rceil}\)\(t\) 的映射,则最小非负整数解为 \(i \lceil \sqrt{p} \rceil - t\)

如果遍历 \(A\) 的所有取值都无法找到最小非负整数解,则原方程无解

时间复杂度为 \(O(\sqrt{p})\),用 map 则多 \(log\) 倍常数。

注意 在模 \(p\) 意义下,如果 \(n = 1\)\(p = 1\),则原方程的最小非负整数解为 \(0\);如果 \(a = 0\),若 \(n = 0\),原方程的最小非负整数解为 \(1\),反之原方程无解

注意 map 的时间复杂度为 \(log\) 但不会发生哈希冲突,\(\mathcal{O}(1)\)unordered_map 可能有哈希冲突。卡常可以考虑手写哈希

模板

P3846 [TJOI2007] 可爱的质数/【模板】BSGS

#include <cstdio>
#include <cmath>
#include <map>
using namespace std;

typedef long long ll;

map<ll, ll> mp;

ll fpow(ll a, ll b, ll mod)
{
	ll res = 1;
	while (b)
	{
		if (b & 1)
			res = (res * a) % mod;
		a = (a * a) % mod;
		b >>= 1;
	}
	return res;
}

ll bsgs(ll a, ll n, ll p)
{
	a %= p, n %= p;
	if (n == 1)
		return 0;
	if (a == 0)
		return (n == 0 ? 1 : -1);
	mp.clear();
	ll m = ceil(sqrt(p));
	for (ll i = 0, t = n; i <= m; i++, t = (t * a) % p)
		mp[t] = i;
	for (ll i = 1, pw = fpow(a, m, p), t = pw; i <= m; i++, t = (t * pw) % p)
		if (mp.count(t) && (i * m - mp[t] >= 0))
			return i * m - mp[t];
	return -1;
}

int main()
{
	ll p, b, n, m;
	scanf("%lld%lld%lld", &p, &b, &n);
	ll ans = bsgs(b, n, p);
	if (ans != -1)
		printf("%lld\n", ans);
	else
		puts("no solution");
	return 0;
}

exBSGS

概念

扩展 BSGS(exBSGS),用于求解关于 \(x\) 的形如 \(a^x \equiv n(\bmod p)\) 的高次同余方程的最小非负整数解,其中 \(a, b, p\) 为常数且 \(a, p\) 不一定 互质。

思想

考虑将方程化为某种形式,使得 \(a, p\) 互质。

原方程可以等价地写成 \(a^x + kp = n, k \in \mathbb{Z}\)

\(\gcd(a, p) = d\),根据裴蜀定理知,原方程有解当且仅当 \(d \mid n\),反之原方程无解。

若原方程有解,则其可进一步化为 \(\frac{a^x}{d} + k \frac{p}{d} = \frac{n}{d}\)

\(a^{x - 1} \frac{a}{d} + k \frac{p}{d} = \frac{n}{d}\)

此时 \(a^{x - 1} \rightarrow a^x, \frac{p}{d} \rightarrow p, \frac{n}{d} \rightarrow n\)

递归重复若干次,使得 \(\gcd(a, p) = 1\)

设此时共递归了 \(k\) 次,\(k\) 次递归中的 \(d\) 乘积为 \(d^{\prime}\)

\(a^{x - k} \frac{a^k}{d^{\prime}} \equiv \frac{p}{d^{\prime}}(\bmod \frac{n}{d^{\prime}})\ (1)\)

用 BSGS 求出方程 \(a^{x - k} \equiv \frac{p}{d^{\prime}}(\bmod \frac{n}{d^{\prime}})\) 的最小非负整数解 \(x^{\prime}\)

则此时原方程的最小非负整数解为 \(x^{\prime} + k\)

注意 \((1)\) 式中有系数 \(\frac{a^k}{d^{\prime}}\) 需要乘上,需要作为参数传入 BSGS 函数

详见代码

模板

P4195 【模板】扩展 BSGS/exBSGS

#include <cstdio>
#include <cmath>
#include <map>
using namespace std;

typedef long long ll;

ll a, p, n;
map<ll, ll> mp;

ll gcd(ll a, ll b)
{
	if (b == 0)
		return a;
	return gcd(b, a % b);
}

ll fpow(ll a, ll b, ll mod)
{
	ll res = 1;
	while (b)
	{
		if (b & 1)
			res = (res * a) % mod;
		a = (a * a) % mod;
		b >>= 1;
	}
	return res;
}

ll bsgs(ll a, ll n, ll p, ll ad)
{
	ll m = ceil(sqrt(p));
	for (ll i = 0, t = n; i <= m; i++, t = (t * a) % p)
		mp[t] = i;
	for (ll i = 0, pw = fpow(a, m, p), t = ad; i <= m; i++, t = (t * pw) % p)
		if (mp.count(t))
			if (i * m - mp[t] >= 0)
				return i * m - mp[t];
	return -1;
}

ll exbsgs(ll a, ll n, ll p)
{
	a %= p, n %= p;
	if ((n == 1) || (p == 1))
		return 0;
	if (a == 0)
		return (n == 0 ? 1 : -1);
	ll cnt = 0, ad = 1, d;
	while ((d = gcd(a, p)) != 1)
	{
		if (n % d != 0)
			return -1;
		cnt++;
		n /= d, p /= d;
		ad = (ad * (a / d)) % p;
		if (ad == n)
			return cnt;
	}
	ll ans = bsgs(a, n, p, ad);
	if (ans == -1)
		return -1;
	return ans + cnt;
}

int main()
{
	while (scanf("%lld%lld%lld", &a, &p, &n))
	{
		if (a == 0)
			break;
		mp.clear();
		ll ans = exbsgs(a, n, p);
		if (ans != -1)
			printf("%lld\n", ans);
		else
			puts("No Solution");
	}
	return 0;
}

矩阵 BSGS

对于模数 \(p\) 和矩阵 \(A, B\),求 \(A^x \equiv B \pmod p\) 的最小正整数解。

BSGS 对于矩阵同样成立,考虑手写矩阵乘法。

判断矩阵是否相等可以写哈希。

例题:BZOJ4128 Matrix

#include <cstdio>
#include <cstring>
#include <cmath>
#include <unordered_map>
using namespace std;

#pragma GCC optimize(2)

#define il inline

typedef unsigned long long ull;

const int maxn = 71;
const int base = 998244353;

il int read()
{
    int res = 0;
    char ch = getchar();
    while ((ch < '0') || (ch > '9')) ch = getchar();
    while ((ch >= '0') && (ch <= '9')) res = (res << 3) + (res << 1) + (ch ^ 48), ch = getchar();
    return res;
}

int n = read(), mod = read();

il short cmod(int x, const short& mod) { return (x >= mod ? x - mod : x); }

struct Matrix
{
    int n, a[maxn][maxn];

    Matrix(int _n = 0) : n(_n) { memset(a, 0, sizeof(a)); }
    int* operator [](int k) { return a[k]; }

    friend Matrix operator * (Matrix a, Matrix b)
    {
        Matrix c(a.n);
        for (int i = 1; i <= a.n; i++)
            for (int k = 1; k <= a.n; k++)
                for (int j = 1; j <= a.n; j++)
                    c[i][j] = cmod(c[i][j] + (int)a[i][k] * b[k][j] % mod, mod);
        return c;
    }

    friend Matrix operator == (Matrix a, Matrix b)
    {
        if (a.n != b.n) return false;
        for (int i = 1; i <= a.n; i++)
            for (int j = 1; j <= a.n; j++)
                if (a[i][j] != b[i][j]) return false;
        return true;
    }

    il bool is_eps()
    {
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= n; j++)
                if (a[i][j] != (i == j)) return false;
        return true;
    }

    il void epsilon()
    {
        memset(a, 0, sizeof(a));
        for (int i = 1; i <= n; i++) a[i][i] = 1;
    }

    il ull hsh()
    {
        ull res = 0;
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= n; j++)
                res = res * base + a[i][j];
        return res;
    }

    il Matrix cpy()
    {
        Matrix res(n);
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= n; j++)
                res.a[i][j] = a[i][j];
        return res;
    }
} a, b;

il int bsgs()
{
    if (b.is_eps()) return 0;
    int t = sqrt(mod) + 1;
    Matrix cur = b.cpy(), pwa(a.n); pwa.epsilon();
    unordered_map<ull, int> vis;
    for (int i = 0; i < t; i++) vis[cur.hsh()] = i, cur = cur * a, pwa = pwa * a;
    a = cur = pwa;
    for (int i = 1; i <= t; i++)
    {
        ull val = cur.hsh();
        if (vis.count(val)) return i * t - vis[val];
        cur = cur * a;
    }
    return -1;
}

int main()
{
    // printf("%d %d\n", n, mod);
    a.n = b.n = n;
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            a[i][j] = read();
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            b[i][j] = read();
    printf("%d\n", bsgs());
    return 0;
}
posted @ 2021-12-03 21:28  kymru  阅读(105)  评论(0编辑  收藏  举报