bzoj1420/1319 Discrete Root
传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=1420
http://www.lydsy.com/JudgeOnline/problem.php?id=1319
【题解】
求所有满足 $x^A \equiv B \pmod P$ 且 $0 \le x < P$ 的 x,其中 P 是质数。
考虑对两边取log,设g为P的原根。
$A \log_g x \equiv \log_g B \pmod{P-1}$(因为 $g^{P-1} \equiv 1 \pmod P$,指数只需在模 P-1 意义下考虑)
其中 $\log_g$ 表示以原根 g 为底的离散对数。注意只有 x、B 都不为 0 时才能取对数:若 $B \equiv 0 \pmod P$,由于 P 是质数,唯一解为 x=0,需要特判。
记 $y = \log_g B$,即满足 $g^y \equiv B \pmod P$ 的指数,用 BSGS 求出即可。
我们要求的是x,不妨先求log(x),设ans=log(x)
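那么存在整数 k 使 $A \cdot ans + (P-1) \cdot k = y$,这是 exgcd 的形式。设 $d = \gcd(A, P-1)$:若 $d \nmid y$ 则无解;否则 ans 在 $[0, P-1)$ 内恰有 d 个解,相邻两解相差 $(P-1)/d$(ans 相当于指数,所以只取小于 P-1 的代表元)。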
最后对每个 ans 用快速幂求出 $x = g^{ans} \bmod P$,排序后输出即可。
// bzoj1420 / bzoj1319 (SGU 261) Discrete Root.
// Reads "P A B" (P prime) and prints how many x in [0, P) satisfy
// x^A ≡ B (mod P), followed by those x in ascending order, one per line.
#include <map>
#include <math.h>
#include <stdio.h>
#include <assert.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
typedef long long ll;

// Returns a^b mod P for b >= 0. Intermediate products are below P^2, so P
// must be small enough that P*P fits in a long long (P <= 1e9 here).
inline ll pwr(ll a, ll b, ll P) {
    ll ret = 1;
    a %= P;
    while (b) {
        if (b & 1) ret = ret * a % P;
        a = a * a % P;
        b >>= 1;
    }
    return ret;
}

// Returns the smallest primitive root of the prime p.
// g is a generator iff g^((p-1)/q) != 1 (mod p) for every prime factor q of p-1.
ll G(ll p) {
    // (Z/2Z)* = {1}; the search below would otherwise return 2, which is ≡ 0.
    if (p == 2) return 1;
    const ll phi = p - 1;
    vector<ll> factors;
    ll rest = phi;
    // long long index so i*i cannot overflow for large moduli.
    for (ll i = 2; i * i <= rest; ++i) {
        if (rest % i) continue;
        factors.push_back(i);
        while (rest % i == 0) rest /= i;
    }
    if (rest != 1) factors.push_back(rest);
    for (ll cand = 2;; ++cand) {
        bool ok = true;
        for (size_t k = 0; k < factors.size(); ++k) {
            if (pwr(cand, phi / factors[k], p) == 1) {
                ok = false;
                break;
            }
        }
        if (ok) return cand;
    }
}

// Baby-step giant-step: returns some e >= 1 with A^e ≡ B (mod P), or -1 when
// no such e exists. Requires gcd(A, P) == 1; B must be nonzero mod P for a
// solution to exist.
ll BSGS(ll A, ll B, ll P) {
    map<ll, ll> mp;
    const ll m = (ll)ceil(sqrt(1.0 * P));
    ll t = B % P;
    // Baby steps: remember B*A^j -> j (first, i.e. smallest, j wins).
    for (ll j = 0; j < m; ++j) {
        if (!mp.count(t)) mp[t] = j;
        t = t * A % P;
    }
    // Giant steps: A^(i*m) == B*A^j  =>  A^(i*m - j) == B.
    const ll giant = pwr(A, m, P);
    t = giant;
    for (ll i = 1; i <= m + 1; ++i) {
        map<ll, ll>::iterator it = mp.find(t);
        if (it != mp.end()) return i * m - it->second;
        t = t * giant % P;
    }
    return -1;
}

// Extended Euclid: returns gcd(a, b) and sets x, y with a*x + b*y = gcd(a, b).
ll exgcd(ll a, ll b, ll &x, ll &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    ll ret = exgcd(b, a % b, x, y);
    ll t = x;
    x = y;
    y = t - (a / b) * y;
    return ret;
}

int main() {
    ll P, A, B;
    cin >> P >> A >> B;
    B %= P;
    // 0 has no discrete logarithm, so BSGS cannot be used for B ≡ 0. Because
    // P is prime, x^A ≡ 0 (mod P) forces x = 0 (assumes A >= 1, as the
    // problem statement guarantees).
    if (B == 0) {
        puts("1");
        puts("0");
        return 0;
    }
    const ll g = G(P);
    // Taking log base g of both sides: A*log(x) ≡ log(B) (mod P-1).
    const ll B0 = BSGS(g, B, P);
    assert(B0 != -1);  // g generates every nonzero residue, so a log exists
    ll tx, ty;
    const ll d = exgcd(A, P - 1, tx, ty);  // A*tx + (P-1)*ty = d
    if (B0 % d) {
        puts("0");
        return 0;
    }
    // All solutions of A*e ≡ B0 (mod P-1) are e0 + k*step, k = 0..d-1.
    const ll step = (P - 1) / d;
    // tx is the inverse of A/d modulo step; scaling by B0/d gives e0.
    ll e = (tx % step + step) % step;
    e = e * (B0 / d) % step;  // e < step <= 1e9 and B0/d <= ~1e9: fits in ll
    vector<ll> ans;
    for (; e < P - 1; e += step) ans.push_back(pwr(g, e, P));
    sort(ans.begin(), ans.end());
    printf("%d\n", (int)ans.size());
    for (size_t i = 0; i < ans.size(); ++i) printf("%lld\n", ans[i]);
    return 0;
}