Loading

P8338 [AHOI2022] 排列

建边:\(i \to p_i\),这样会形成若干个置换环,每次操作相当于每个点同时走一步。

记置换环的数量为 \(m\),从 \(1\)\(m\) 编号,第 \(i\) 个置换环的大小是 \(s_i\)\(bel_i\) 为点 \(i\) 所属的置换环编号。

显然 \(f(i, j) = 0\) 的充要条件是 \(i, j\) 在同一置换环上,否则 \(f(i, j) = \operatorname{lcm}\{\{s_k | k \ne bel_i, bel_j\} + \{s_{bel_i} + s_{bel_j}\}\}\)

暴力枚举两个合并的置换环统计答案的时间复杂度是 \(\mathcal O(m^3 \log n)\),考虑优化。

其实我们并不关心合并的是哪两个置换环,我们关心的仅是它们俩的 \(s\)

\(S = \{s_i | 1 \le i \le m \}\),又 \(\sum\limits_{i = 1}^m s_i = n\),故 \(|S| \le \sqrt n\),暴力枚举 \(S\) 里的任意两个的时间复杂度是 \(\mathcal O(n)\) 的。

问题来到如何快速求 \(\rm lcm\),存每个质因数的前 \(3\) 大指数,然后动态维护就行,时间复杂度 \(\mathcal O(n \log n)\)

代码:

#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

constexpr int N = 5e5 + 10, MOD = 1e9 + 7;

int n, m, a[N], cir[N], cnt[N];

namespace DSU {
    int fa[N], sz[N];
    void init(int n) {for (int i = 1; i <= n; i++) fa[i] = i, sz[i] = 1;}
    int find(int x) {return x == fa[x] ? x : fa[x] = find(fa[x]);}
    void merge(int x, int y) {
        int fx = find(x), fy = find(y);
        if (fx != fy) {
            if (sz[fx] < sz[fy]) swap(fx, fy);
            fa[fy] = fx, sz[fx] += sz[fy];
        }
    }
}

ll inv(ll base, int e = MOD - 2) {
    ll res = 1;
    while (e) {
        if (e & 1) res = res * base % MOD;
        base = base * base % MOD;
        e >>= 1;
    }
    return res;
}

int mx[N][3]; ll lcm = 1;
inline void fix(int mx[], int p) {
    if (p > mx[0]) lcm = lcm * (p / mx[0]) % MOD, mx[2] = mx[1], mx[1] = mx[0], mx[0] = p;
    else if (p > mx[1]) mx[2] = mx[1], mx[1] = p;
    else mx[2] = max(mx[2], p);
}
inline void siu(int mx[], int p) {
    if (p == mx[2]) mx[2] = 1;
    else if (p == mx[1]) mx[1] = mx[2], mx[2] = 1;
    else if (p == mx[0]) lcm = lcm * inv(mx[0]) % MOD * mx[1] % MOD, mx[0] = mx[1], mx[1] = mx[2], mx[2] = 1;
}
void add(int x) {
    for (int i = 2; i * i <= x; i++) if (x % i == 0) {
        int p = 1;
        while (x % i == 0) x /= i, p *= i;
        fix(mx[i], p);
    }
    if (x > 1) fix(mx[x], x);
}
void del(int x) {
    for (int i = 2; i * i <= x; i++) if (x % i == 0) {
        int p = 1;
        while (x % i == 0) x /= i, p *= i;
        siu(mx[i], p);
    }
    if (x > 1) siu(mx[x], x);
}

void solve() {
    cin >> n; DSU::init(n);
    for (int i = 1; i <= n; i++) cin >> a[i], DSU::merge(i, a[i]), mx[i][0] = mx[i][1] = mx[i][2] = 1;
    m = 0; memset(cnt, 0, sizeof(cnt)); lcm = 1;
    for (int i = 1; i <= n; i++) if (DSU::find(i) == i) {
        add(DSU::sz[i]);
        if(!cnt[DSU::sz[i]]++) cir[++m] = DSU::sz[i];
    }
    ll ans = 0;
    for (int i = 1; i <= m; i++) {
        del(cir[i]);
        if (cnt[cir[i]] > 1) {
            ll cur = lcm; int p2 = __builtin_ctz(cir[i]) + 1;
            if ((1 << p2) > mx[2][0]) cur <<= 1;
            ans = (ans + cur * (cir[i] * cnt[cir[i]]) % MOD * (cir[i] * (cnt[cir[i]] - 1))) % MOD;
        }
        for (int j = i + 1; j <= m; j++) {
            del(cir[j]);
            int now = cir[i] + cir[j]; ll cur = lcm;
            for (int k = 2; k * k <= now; k++) if (now % k == 0) {
                int p = 1;
                while (now % k == 0) now /= k, p *= k;
                if (p > mx[k][0]) cur = cur * (p / mx[k][0]) % MOD;
            }
            if (now > mx[now][0]) cur = cur * now % MOD;
            ans = (ans + 2 * cur * (cir[i] * cnt[cir[i]]) % MOD * (cir[j] * cnt[cir[j]])) % MOD;
            add(cir[j]);
        }
        add(cir[i]);
    }
    cout << ans << '\n';
}

int main() {
    ios_base::sync_with_stdio(0); cin.tie(nullptr), cout.tie(nullptr);
    int t; cin >> t;
    while (t--) solve();
    return 0;
}
posted @ 2024-02-03 16:18  Chy12321  阅读(10)  评论(0编辑  收藏  举报