[AGC005D] ~K Perm Counting

题意

求对于所有的 \(i\) 满足 \(|P_i - i| \neq k\),的排列数量,对 \(924844033\) 取模。

\(2 \le n \le 2 \times 10 ^ 3, 1 \le k \le n - 1\)

Sol

考虑转成 \(n \times n\) 的网格图,那么就是所有 \((i, i + k)\) 以及 \((i, i - k)\) 的格子涂黑不能用。

题意转化为在网格图里放 \(n\) 个车,不能放在黑格子里,互相不攻击的方案数。

考虑尝试把这个限制容斥掉,设 \(f(x)\) 表示我们钦定在黑色格子里放 \(x\) 个车的方案数。

那么答案就是:

\[\sum \sum_{i = 0} ^ n (-1) ^ i f(i) (n - i)! \]

如何计算 \(f\)?注意到我们将每行每列的黑格子连起来,然后就变成了若干条链,不能选择链上相邻的点,直接在链上 \(\texttt{dp{\) 即可。

Code

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <bitset>
#include <vector>
using namespace std;
#ifdef ONLINE_JUDGE

#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;

#endif
int read() {
    int p = 0, flg = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') flg = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        p = p * 10 + c - '0';
        c = getchar();
    }
    return p * flg;
}
void write(int x) {
    if (x < 0) {
        x = -x;
        putchar('-');
    }
    if (x > 9) {
        write(x / 10);
    }
    putchar(x % 10 + '0');
}
bool _stmer;

const int N = 4e3 + 5, M = 8e6 + 5, mod = 924844033;

array <array <array <int, 2>, N>, N> f;

void Mod(int &x) {
    if (x >= mod) x -= mod;
    if (x < 0) x += mod;
}

namespace Uni {

array <int, M> fa;

int find(int x) {
    if (x == fa[x]) return x;
    return fa[x] = find(fa[x]);
}

void merge(int x, int y) {
    int fx = find(x),
        fy = find(y);
    if (fx == fy) return;
    fa[fy] = fx;
}

void init(int n) {
    for (int i = 1; i <= n; i++)
        fa[i] = i;
}

} //namespace Uni

int getid(int x, int y) { return (x - 1) * ((int)4e3) + y; }

array <bitset <N>, N> vis;
array <int, M> siz;

bool _edmer;
int main() {
    cerr << (&_stmer - &_edmer) / 1024.0 / 1024.0 << "MB\n";
    int n = read(), k = read();
    Uni::init(8e6);
    for (int i = 1; i <= n; i++) {
        if (i + k <= n) vis[i][i + k] = 1;
        if (i - k > 0) vis[i][i - k] = 1;
    }
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            if (!vis[i][j]) continue;
            if (i + 2 * k <= n && vis[i + 2 * k][j])
                Uni::merge(getid(i, j), getid(i + 2 * k, j));
            if (j + 2 * k <= n && vis[i][j + 2 * k])
                Uni::merge(getid(i, j), getid(i, j + 2 * k));
        }
    }
    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            if (vis[i][j])
                siz[Uni::find(getid(i, j))]++;
    vector <int> isl;
    for (int i = 1; i <= 8e6; i++)
        if (siz[i]) isl.push_back(siz[i]);
#define upd(x, y) (x += y, Mod(x))
    f[0][0][0] = 1;
    int lst = 0;
    for (auto p : isl) {
        for (int i = lst + 1; i <= lst + p; i++) {
            for (int j = n; ~j; j--) {
                upd(f[i][j][0], f[i - 1][j][0]);
                upd(f[i][j][0], f[i - 1][j][1]);
                if (j)
                    upd(f[i][j][1], f[i - 1][j - 1][0]);
            }
        }
        lst += p;
        for (int i = 1; i <= n; i++)
            upd(f[lst][i][0], f[lst][i][1]), f[lst][i][1] = 0;
    }
    int ans = 0, res = 1;
    for (int i = n; ~i; i--) {
        ans += ((i & 1) ? -1ll : 1ll) * (f[lst][i][0] + f[lst][i][1]) * res % mod, Mod(ans);
        res = 1ll * res * (n - i + 1) % mod;
    }
    write(ans), puts("");
    return 0;
}
posted @ 2024-11-11 19:17  cxqghzj  阅读(3)  评论(0编辑  收藏  举报