AtCoder Beginner Contest 335 G Discrete Logarithm Problems
考虑若我们对于每个 \(a_i\) 求出来了使得 \(g^{b_i} \equiv a_i \pmod P\) 的 \(b_i\)(其中 \(g\) 为 \(P\) 的原根),那么 \(a_i^k \equiv a_j \pmod P\) 等价于 \(kb_i \equiv b_j \pmod{P - 1}\),有解的充要条件是 \(\gcd(b_i, P - 1) \mid b_j\)。
显然我们不可能对于每个 \(a_i\) 都求出来 \(b_i\)。注意到我们只关心 \(c_i = \gcd(b_i, P - 1)\),而 \(c_i\) 为满足 \(a_i^{c_i} \equiv 1 \pmod P\) 的最小正整数。若求出 \(c_i\) 则等价于统计 \(c_i \mid c_j\) 的对数。于是问题变成求出 \(c_i\)。
因为我们一定有 \(a_i^{P - 1} \equiv 1 \pmod P\),所以 \(c_i\) 一定为 \(P - 1\) 的因数。所以我们初始令 \(c_i = P - 1\),然后对 \(P - 1\) 分解质因数,依次让 \(c_i\) 试除 \(P - 1\) 的每个质因子,判断除完后是否还有 \(a_i^{c_i} \equiv 1 \pmod P\) 即可。这部分复杂度大概是 \(O(n \log^2 P)\) 的。
问题还剩下统计 \(c_i \mid c_j\) 的对数。因为 \(c_i\) 为 \(P - 1\) 的因数,所以我们可以做一遍 Dirichlet 后缀和求出 \(f_x\) 表示满足 \(x \mid c_i\) 的 \(i\) 的个数。最后遍历 \(c_i\) 统计即可。
总时间复杂度大概是 \(O(n \log^2 P + m \log m \log P)\),其中 \(m\) 为 \(P - 1\) 因数个数。
code
// Problem: G - Discrete Logarithm Problems
// Contest: AtCoder - AtCoder Beginner Contest 335 (Sponsored by Mynavi)
// URL: https://atcoder.jp/contests/abc335/tasks/abc335_g
// Memory Limit: 1024 MB
// Time Limit: 5000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef __int128 lll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 200100;
ll n, m, a[maxn], tot, T, c[maxn], tt, f[maxn];
pii b[maxn];
inline ll qpow(ll b, ll p, const ll &mod) {
ll res = 1;
while (p) {
if (p & 1) {
res = (lll)res * b % mod;
}
b = (lll)b * b % mod;
p >>= 1;
}
return res;
}
void solve() {
scanf("%lld%lld", &n, &m);
ll x = m - 1;
for (ll i = 1; i * i <= x; ++i) {
if (x % i) {
continue;
}
c[++tt] = i;
if (i * i != x) {
c[++tt] = x / i;
}
}
sort(c + 1, c + tt + 1);
for (ll i = 2; i * i <= x; ++i) {
if (x % i == 0) {
ll cnt = 0;
while (x % i == 0) {
x /= i;
++cnt;
}
b[++tot] = mkp(i, cnt);
}
}
if (x > 1) {
b[++tot] = mkp(x, 1);
}
for (int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
ll x = m - 1;
for (int j = 1; j <= tot; ++j) {
for (int _ = 0; _ < b[j].scd; ++_) {
if (qpow(a[i], x / b[j].fst, m) == 1) {
x /= b[j].fst;
}
}
}
a[i] = x;
++f[lower_bound(c + 1, c + tt + 1, a[i]) - c];
}
for (int i = 1; i <= tot; ++i) {
for (int j = tt; j; --j) {
if (c[j] % b[i].fst) {
continue;
}
ll x = lower_bound(c + 1, c + tt + 1, c[j] / b[i].fst) - c;
f[x] += f[j];
}
}
ll ans = 0;
for (int i = 1; i <= n; ++i) {
ans += f[lower_bound(c + 1, c + tt + 1, a[i]) - c];
}
printf("%lld\n", ans);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}