2024 BUPT Programming Contest F
简要题意
多测,给定一个 \(n \times n\) 矩阵,矩阵中的每一个元素的计算方式如下:
- 矩阵的行和列唯一决定两个整数对 \((a, b)\),矩阵第 \(a(0 \le a < n)\) 行第 \(b(0 \le b < n)\) 列的元素为 \(a \times b \bmod n\)
求矩阵中元素 \(m\) 出现的次数。
-
\(0 \le m < n \le 10^{12}\)
-
\(1 \le T \le 100\)
题解
对于矩阵中的任意一个元素是独立的,因此我们考虑对于一组 \(a \times b \equiv m \pmod n\) 的合法性。
原式可推出 \(ab + kn = m\),由裴蜀定理可知,当 \(\gcd(a, n) \mid m\) 时,方程有线性整数解。
接下来考虑对于一个合法的 \(a\) 有多少组解是可以被接受的,由于 \(at = sn\),我们想要找到尽可能小的 \((s, t)\) 的一组解,那么 \(t = \frac{n}{\gcd(a, n)}\),若 \(ab \equiv m \pmod n\),那么 \(a\left(b + \frac{n}{\gcd(a, n)}\right) \equiv m \pmod n\)。
对于最小的二元组 \((a, b)\) 使原式成立,必有 \(b < \frac{n}{\gcd(a, n)}\) 成立,若有 \(b \ge \frac{n}{\gcd(a, n)}\),则由 \(a\left(b + \frac{n}{\gcd(a, n)}\right) \equiv m \pmod n\) 可递归定义最小的 \(b\)。
因此,对于任何合法的 \(a\),必存在最小的二元组 \((a, b)\) 使得同余式成立,因此所有的合法解为:
\[\begin{aligned}
\sum_{a = 0}^{n - 1}\left\lfloor\frac{n}{\frac{n}{\gcd(a, n)}}\right\rfloor[\gcd(a, n) \mid m] &= \sum_{a = 0}^{n - 1}\gcd(a, n)[\gcd(a, n) \mid m] \\
&= \sum_{a = 1}^{n}\gcd(a, n)[\gcd(a, n) \mid m] \\
&= \sum_{d \mid m}d\sum_{a = 1}^{n}[\gcd(a, n) = d] \\
&= \sum_{d \mid m, d \mid n}d\sum_{a = 1}^{n}[\gcd(a, n) = d] \\
&= \sum_{d \mid \gcd(n, m)}d\sum_{a = 1}^{n}\left[\gcd\left(\frac{a}{d}, \frac{n}{d}\right) = 1\right] \\
&= \sum_{d \mid \gcd(n, m)}d\sum_{a = 1}^{\frac{n}{d}}\left[\gcd\left(a, \frac{n}{d}\right) = 1\right] \\
&= \sum_{d \mid \gcd(n, m)}d\varphi\left(\frac{n}{d}\right)
\end{aligned}
\]
参考代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e6 + 10, mod = 998244353;
ll n, m;
ll primes[N], phi[N], cnt;
bool st[N];
void init()
{
phi[1] = 1;
for (int i = 2; i < N; i ++ )
{
if (!st[i]) primes[cnt ++ ] = i, phi[i] = i - 1;
for (int j = 0; primes[j] * i < N; j ++ )
{
st[primes[j] * i] = 1;
if (i % primes[j] == 0)
{
phi[primes[j] * i] = phi[i] * primes[j];
break;
}
phi[primes[j] * i] = phi[i] * (primes[j] - 1);
}
}
}
ll get_phi(ll x)
{
if (x < N) return phi[x];
ll res = 1;
for (int i = 0; primes[i] * primes[i] <= x; i ++ )
{
if (x % primes[i] == 0) res = res * (primes[i] - 1) % mod, x /= primes[i];
while (x % primes[i] == 0) res = res * primes[i] % mod, x /= primes[i];
}
if (x > 1) res = res * (x - 1) % mod;
return res;
}
void solve()
{
cin >> m >> n;
ll g = __gcd(m, n), ans = 0;
for (ll i = 1; i * i <= g; i ++ )
if (g % i == 0)
{
(ans += i * get_phi(n / i) % mod) %= mod;
if (i * i != g) (ans += (g / i) % mod * get_phi(n / (g / i)) % mod) %= mod;
}
cout << ans << endl;
}
int main()
{
int T;
init();
cin >> T;
while (T -- ) solve();
return 0;
}