P8338 [AHOI2022] 排列
建边:\(i \to p_i\),这样会形成若干个置换环,每次操作相当于每个点同时走一步。
记置换环的数量为 \(m\),从 \(1\) 到 \(m\) 编号,第 \(i\) 个置换环的大小是 \(s_i\),\(bel_i\) 为点 \(i\) 所属的置换环编号。
显然 \(f(i, j) = 0\) 的充要条件是 \(i, j\) 在同一置换环上;否则 \(bel_i, bel_j\) 这两个环合并为一个大小为 \(s_{bel_i} + s_{bel_j}\) 的环,其余环不变,于是 \(f(i, j) = \operatorname{lcm}\left(\{s_k \mid k \notin \{bel_i, bel_j\}\} \cup \{s_{bel_i} + s_{bel_j}\}\right)\)。
暴力枚举要合并的两个置换环(共 \(\mathcal O(m^2)\) 对),每对再用 \(\mathcal O(m \log n)\) 的时间求 \(\operatorname{lcm}\),总时间复杂度是 \(\mathcal O(m^3 \log n)\),考虑优化。
其实我们并不关心合并的是哪两个置换环,我们关心的仅是它们俩的 \(s\)。
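令 \(S = \{s_i \mid 1 \le i \le m\}\) 为所有不同的环长(记 \(cnt_s\) 为大小为 \(s\) 的环的个数)。由于 \(\sum\limits_{i = 1}^m s_i = n\),且 \(S\) 中的元素互不相同,故 \(1 + 2 + \cdots + |S| \le n\),即 \(|S| < \sqrt{2n}\)。因此枚举 \(S\) 中任意两个元素(两者相等时要求 \(cnt_s \ge 2\))的时间复杂度是 \(\mathcal O(|S|^2) = \mathcal O(n)\) 的。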
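问题转化为如何快速求 \(\operatorname{lcm}\):\(\operatorname{lcm}\) 等于每个质数 \(p\) 在各环长中出现的最大幂次之积,所以对每个质数 \(p\) 存下所有环长中 \(p\) 的幂次(即 \(p^e\))的前 \(3\) 大值即可。由于同一时刻最多只删去两个环(被合并的那两个),保留前 \(3\) 大就能保证删除后最大值仍然正确;删除后再把合并得到的环长分解质因数,与当前最大幂次比较,即可动态维护。若用线性筛预处理最小质因子来分解质因数,时间复杂度为 \(\mathcal O(n \log n)\);下面的代码直接试除分解,试除部分的总代价为 \(\mathcal O(n^{5/4})\)。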
代码:
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for all modular arithmetic.
using ll = long long;
// Array capacity (max n plus slack) and the answer modulus.
constexpr int N = 500010;
constexpr int MOD = 1000000007;
// n: permutation length; m: number of distinct cycle sizes.
int n, m;
// a: the permutation; cir[1..m]: each distinct cycle size once;
// cnt[s]: how many cycles have size s.
int a[N], cir[N], cnt[N];
// Union-find over vertices 1..n. solve() unites i with a[i], so every set
// ends up being one cycle of the permutation and sz[root] is its length.
namespace DSU {
int fa[N], sz[N];  // fa: parent link; sz: set size (meaningful at roots)

// Turn every vertex in 1..n into its own singleton set.
void init(int n) {
    for (int v = 1; v <= n; ++v) {
        fa[v] = v;
        sz[v] = 1;
    }
}

// Return the root of x's set; every vertex on the walked path is re-pointed
// directly at the root (path compression).
int find(int x) {
    int root = x;
    while (fa[root] != root) root = fa[root];
    while (fa[x] != root) {
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}

// Union by size: the root of the smaller set is hung under the larger one.
void merge(int x, int y) {
    int rx = find(x);
    int ry = find(y);
    if (rx == ry) return;
    if (sz[rx] < sz[ry]) swap(rx, ry);
    fa[ry] = rx;
    sz[rx] += sz[ry];
}
}
// Modular exponentiation: returns base^e mod MOD for e >= 0 (e < 0 yields 1).
// With the default exponent MOD - 2 this is the modular inverse of base
// (Fermat's little theorem, MOD = 1e9+7 is prime), hence the name.
ll inv(ll base, int e = MOD - 2) {
    // Normalize into [0, MOD) so base * base can never overflow and a
    // negative base still produces a canonical residue.
    base %= MOD;
    if (base < 0) base += MOD;
    ll res = 1;
    // `e > 0` (not `e != 0`): a negative e would otherwise shift forever at -1.
    for (; e > 0; e >>= 1) {
        if (e & 1) res = res * base % MOD;
        base = base * base % MOD;
    }
    return res;
}
// mx[p][0..2]: the three largest powers of prime p (stored as values p^e)
// among the current multiset of cycle sizes, non-increasing; 1 = empty slot.
int mx[N][3] = {};
// Product over primes p of mx[p][0], mod MOD: the lcm of the current sizes.
ll lcm{1};
// Insert the prime power p into the top-3 table mx (mx[0] >= mx[1] >= mx[2]).
// When p becomes the new maximum the global lcm grows by p / old maximum,
// which is exact because both are powers of the same prime.
inline void fix(int mx[], int p) {
    if (p > mx[0]) {
        lcm = lcm * (p / mx[0]) % MOD;
        mx[2] = mx[1];
        mx[1] = mx[0];
        mx[0] = p;
    } else if (p > mx[1]) {
        mx[2] = mx[1];
        mx[1] = p;
    } else if (p > mx[2]) {
        mx[2] = p;
    }
}
// Remove one occurrence of the prime power p from the top-3 table mx; the
// freed last slot becomes 1 (empty). If the maximum itself leaves, the global
// lcm is divided by it (modular inverse) and multiplied by the new maximum.
// A p that is not among the three tracked values is ignored.
inline void siu(int mx[], int p) {
    if (p == mx[2]) {
        mx[2] = 1;
        return;
    }
    if (p == mx[1]) {
        mx[1] = mx[2];
        mx[2] = 1;
        return;
    }
    if (p != mx[0]) return;
    lcm = lcm * inv(mx[0]) % MOD * mx[1] % MOD;
    mx[0] = mx[1];
    mx[1] = mx[2];
    mx[2] = 1;
}
// Invoke f(p, q) for every prime p dividing x, where q = p^e is the exact
// power of p in x, in increasing order of p. Uses trial division; x <= 1
// produces no calls.
template <class F>
void for_each_prime_power(int x, F&& f) {
    for (int d = 2; d * d <= x; d++) {
        if (x % d != 0) continue;
        int q = 1;
        while (x % d == 0) x /= d, q *= d;
        f(d, q);
    }
    // Whatever remains above 1 is one prime factor larger than every d tried.
    if (x > 1) f(x, x);
}

// Insert one cycle of size x: each prime power of x goes into its prime's
// top-3 table (fix() keeps the global lcm in sync).
void add(int x) {
    for_each_prime_power(x, [](int p, int q) { fix(mx[p], q); });
}

// Remove one cycle of size x: each prime power of x leaves its prime's
// top-3 table (siu() keeps the global lcm in sync).
void del(int x) {
    for_each_prime_power(x, [](int p, int q) { siu(mx[p], q); });
}
void solve() {
cin >> n; DSU::init(n);
for (int i = 1; i <= n; i++) cin >> a[i], DSU::merge(i, a[i]), mx[i][0] = mx[i][1] = mx[i][2] = 1;
m = 0; memset(cnt, 0, sizeof(cnt)); lcm = 1;
for (int i = 1; i <= n; i++) if (DSU::find(i) == i) {
add(DSU::sz[i]);
if(!cnt[DSU::sz[i]]++) cir[++m] = DSU::sz[i];
}
ll ans = 0;
for (int i = 1; i <= m; i++) {
del(cir[i]);
if (cnt[cir[i]] > 1) {
ll cur = lcm; int p2 = __builtin_ctz(cir[i]) + 1;
if ((1 << p2) > mx[2][0]) cur <<= 1;
ans = (ans + cur * (cir[i] * cnt[cir[i]]) % MOD * (cir[i] * (cnt[cir[i]] - 1))) % MOD;
}
for (int j = i + 1; j <= m; j++) {
del(cir[j]);
int now = cir[i] + cir[j]; ll cur = lcm;
for (int k = 2; k * k <= now; k++) if (now % k == 0) {
int p = 1;
while (now % k == 0) now /= k, p *= k;
if (p > mx[k][0]) cur = cur * (p / mx[k][0]) % MOD;
}
if (now > mx[now][0]) cur = cur * now % MOD;
ans = (ans + 2 * cur * (cir[i] * cnt[cir[i]]) % MOD * (cir[j] * cnt[cir[j]])) % MOD;
add(cir[j]);
}
add(cir[i]);
}
cout << ans << '\n';
}
// Entry point: untie/unsync the C++ streams for speed, read the number of
// test cases and run solve() once for each.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int t;
    cin >> t;
    for (; t != 0; --t) solve();
    return 0;
}