[容斥+NTT+启发式合并] ICPC2021上海站B题 Strange_Permutations

比赛时推出来了没写,血亏,赛后补上。



#include <bits/stdc++.h>
using namespace std;

// Shorthand for `long long`, the integer type used for all modular arithmetic in this file.
#define LL long long
// Capacity (in elements) of the NTT scratch buffers a[] / b[] and of the
// bit-reversal table Poly::r[]; no transform length may exceed it.
const int maxn = 2100000;

// Modular exponentiation by repeated squaring: returns b^n mod MOD.
//   b   - base; may be negative, it is reduced into [0, MOD) first
//   n   - exponent, expected to be >= 0; a negative value is treated like 0
//         (an arithmetic right shift of a negative number never reaches 0,
//         so looping on "n != 0" would never terminate)
//   MOD - positive modulus; products are formed in long long, so (MOD-1)^2
//         must fit, i.e. MOD has to stay below roughly 3.03e9
// The result lies in [0, MOD); for MOD == 1 it is always 0.
long long qpow(long long b, long long n, long long MOD) {
    if (MOD == 1) return 0;
    long long base = b % MOD;
    if (base < 0) base += MOD;  // '%' keeps the sign of the dividend in C++
    long long result = 1;
    for (; n > 0; n >>= 1) {
        if (n & 1) result = result * base % MOD;
        base = base * base % MOD;
    }
    return result;
}

// Prime modulus for every computation below: 998244353 = 119 * 2^23 + 1.
const LL P = 998244353;
// Generator used by Poly::NTT to derive the roots of unity of the forward transform.
const LL G = 3;
// Inverse of G modulo P (3 * 332748118 = P + 1), used for the inverse transform.
const LL Gi = 332748118;

// Number-theoretic transform and polynomial multiplication modulo P.
namespace Poly {
    int r[maxn];   // bit-reversal permutation for the current transform length
    int L, limit;  // limit: current transform length (a power of two); L = log2(limit)

    // Modular inverse of x modulo the prime P (x^(P-2) by Fermat's little theorem).
    LL pinv(LL x) { return qpow(x, P - 2, P); }

    // In-place number-theoretic transform of A[0..limit).
    //   type ==  1 : forward transform
    //   type == -1 : inverse transform, including the final division by limit
    // limit and r[] must already describe the transform length (Conv sets them
    // up) and every A[i] is expected to lie in [0, P).
    void NTT(LL* A, int type) {
        // Move the input into bit-reversed order so the butterflies work in place.
        for (int i = 0; i < limit; i++)
            if (i < r[i]) swap(A[i], A[r[i]]);
        for (int mid = 1; mid < limit; mid <<= 1) {
            // Primitive (2*mid)-th root of unity (or its inverse) for this layer.
            LL Wn = qpow(type == 1 ? G : Gi, (P - 1) / (mid << 1), P);
            for (int j = 0; j < limit; j += (mid << 1)) {
                LL w = 1;
                for (int k = 0; k < mid; k++, w = (w * Wn) % P) {
                    // Operands stay in LL: narrowing them to int would silently
                    // truncate any value that is not already below 2^31.
                    LL x = A[j + k], y = w * A[j + k + mid] % P;
                    A[j + k] = (x + y) % P;
                    A[j + k + mid] = (x - y + P) % P;
                }
            }
        }
        if (type == 1) return;
        // The inverse transform additionally scales every entry by limit^{-1}.
        LL inv_limit = pinv(limit);
        for (int i = 0; i < limit; ++i)
            A[i] = A[i] * inv_limit % P;
    }

    // Polynomial product c = a * b modulo P.
    //   a, N - first factor and its coefficient count (degree N-1)
    //   b, M - second factor and its coefficient count (degree M-1)
    //   c    - receives the product; may be the same buffer as a or b
    // All three buffers must hold at least `limit` entries (the smallest power
    // of two greater than N + M, which must not exceed maxn), with a and b
    // zero-padded beyond their coefficients. a and b are overwritten by their
    // forward transforms.
    void Conv(LL* a, int N, LL* b, LL M, LL* c) {
        L = 0; limit = 1;
        while (limit <= N + M) limit <<= 1, L++;
        // r[0] is always 0; starting the recurrence at 1 also avoids the
        // negative shift count "<< (L - 1)" when limit == 1 (L == 0).
        r[0] = 0;
        for (int i = 1; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
        NTT(a, 1);
        // When squaring (a and b are the same buffer) transform it only once.
        if (b != a) NTT(b, 1);
        for (int i = 0; i < limit; i++) c[i] = a[i] * b[i] % P;
        NTT(c, -1);
    }
}

// Heap entry: a polynomial, identified by its index into vec[], together with
// its current number of coefficients.
struct node { int len, id; };
// Comparator that makes priority_queue pop the node with the SMALLEST len first.
// operator() is const so it can be invoked through a const comparator object,
// as the standard Compare requirements expect.
struct cmp { bool operator()(const node& a, const node& b) const { return a.len > b.len; } };
priority_queue<node, vector<node>, cmp> Q;  // polynomials still waiting to be multiplied together
vector<LL> vec[100010];  // vec[i]: coefficients (lowest degree first) of polynomial i
bool vis[100010];        // vis[v]: v has already been assigned to a cycle in main
int nxt[100010];         // nxt[i]: successor of i, read from the input
int num[100010];         // num[k]: number of elements on the k-th cycle found
int inv[100010];         // inv[i]:  i^{-1} mod P
int fact[100010];        // fact[i]: i! mod P
int finv[100010];        // finv[i]: (i!)^{-1} mod P
LL a[maxn], b[maxn];     // scratch buffers handed to Poly::Conv
int n, m;                // n: number of elements read; m: number of cycles found

void Init() {
    inv[1] = fact[0] = fact[1] = finv[0] = finv[1] = 1;
    for (int i = 2;i <= 100000;++i) {
        inv[i] = ((-1LL * (P / i) * inv[P % i]) % P + P) % P;
        fact[i] = 1LL * fact[i - 1] * i % P;
        finv[i] = 1LL * finv[i - 1] * inv[i] % P;
    }
}

// Binomial coefficient "n choose m" modulo P, looked up from the fact[] and
// finv[] tables. Yields 0 whenever m falls outside [0, n].
LL C(LL n, LL m) {
    const bool in_range = (0 <= m) && (m <= n);
    if (!in_range) return 0;
    const LL numerator = fact[n];
    const LL inv_denominator = static_cast<LL>(finv[m]) * finv[n - m] % P;
    return numerator * inv_denominator % P;
}

// Replaces vec[u] with the product vec[u] * vec[v] (mod P), using the global
// scratch buffers a[] and b[]. vec[v] itself is left unchanged.
void Convolution(int u, int v) {
    const int lenU = vec[u].size();
    const int lenV = vec[v].size();

    // Smallest power of two strictly greater than lenU + lenV; this is the
    // same transform length Poly::Conv derives from its arguments.
    int padded = 1;
    while (padded <= lenU + lenV) padded <<= 1;

    // Load both factors, zero-filling everything past their last coefficient.
    for (int i = 0; i < padded; ++i) {
        a[i] = (i < lenU) ? vec[u][i] : 0;
        b[i] = (i < lenV) ? vec[v][i] : 0;
    }

    // The product is written back into a[].
    Poly::Conv(a, lenU, b, lenV, a);

    const int lenOut = lenU + lenV - 1;
    vec[u].resize(lenOut);
    for (int i = 0; i < lenOut; ++i)
        vec[u][i] = a[i];
}

// Multiplies all queued polynomials together, always combining the two
// shortest ones first, and evaluates the inclusion-exclusion sum
//     sum_i (-1)^i * (n - i)! * g[i]   (mod P)
// over the coefficients g[] of the final product.
// Returns the sum reduced into [0, P).
LL solve() {
    // Nothing queued: the empty product is the constant polynomial 1, so only
    // the i = 0 term remains. (Calling Q.top() on an empty queue is undefined.)
    if (Q.empty()) return fact[n];

    while (Q.size() > 1) {
        const int u = Q.top().id; Q.pop();
        const int v = Q.top().id; Q.pop();
        Convolution(u, v);  // product ends up in vec[u]
        Q.push(node{static_cast<int>(vec[u].size()), u});
    }

    const vector<LL>& g = vec[Q.top().id];
    const int terms = static_cast<int>(g.size());
    LL ans = 0;
    for (int i = 0; i < terms; ++i) {
        const LL term = static_cast<LL>(fact[n - i]) * g[i] % P;
        // Odd i contributes with a minus sign; adding P keeps the value non-negative.
        if (i & 1) ans = (ans - term + P) % P;
        else ans = (ans + term) % P;
    }
    return ans;
}

// Reads n and the n successor values, splits the elements into cycles, queues
// one polynomial per cycle and prints the value returned by solve().
// Exits with status 1 (after a message on stderr) when the input is missing,
// malformed or outside the range the fixed-size tables can hold.
int main() {
    Init();
    if (scanf("%d", &n) != 1 || n < 1 || n > 100000) {
        fprintf(stderr, "invalid n\n");
        return 1;
    }
    for (int i = 1; i <= n; ++i) {
        // Every successor is later used as an array index, so it must lie in [1, n].
        if (scanf("%d", &nxt[i]) != 1 || nxt[i] < 1 || nxt[i] > n) {
            fprintf(stderr, "invalid successor for element %d\n", i);
            return 1;
        }
    }

    // Follow nxt[] from every unvisited element; num[k] counts the elements
    // reached on the k-th walk.
    for (int i = 1; i <= n; ++i) {
        if (vis[i]) continue;
        int u = i;
        num[++m] = 1;
        vis[u] = true;
        while (!vis[nxt[u]]) {
            u = nxt[u];
            vis[u] = true;
            ++num[m];
        }
    }

    // Polynomial of cycle i: coefficient j is C(num[i], j), except that the
    // top coefficient C(num[i], num[i]) = 1 is reduced by one.
    // NOTE(review): presumably because picking every edge of a cycle is not an
    // admissible choice in the inclusion-exclusion -- confirm against the problem.
    for (int i = 1; i <= m; ++i) {
        vec[i].resize(num[i] + 1);
        for (int j = 0; j <= num[i]; ++j)
            vec[i][j] = C(num[i], j);
        vec[i][num[i]] = (vec[i][num[i]] + P - 1) % P;
        Q.push(node{static_cast<int>(vec[i].size()), i});
    }
    printf("%lld\n", solve());

    return 0;
}
posted @ 2021-11-28 22:25  AE酱  阅读(278)  评论(0)  编辑  收藏  举报