[AGC005D] ~K Perm Counting
题意
求对于所有的 \(i\) 满足 \(|P_i - i| \neq k\),的排列数量,对 \(924844033\) 取模。
\(2 \le n \le 2 \times 10 ^ 3, 1 \le k \le n - 1\)。
Sol
考虑转成 \(n \times n\) 的网格图,那么就是所有 \((i, i + k)\) 以及 \((i, i - k)\) 的格子涂黑不能用。
题意转化为在网格图里放 \(n\) 个车,不能放在黑格子里,互相不攻击的方案数。
考虑尝试把这个限制容斥掉,设 \(f(x)\) 表示我们钦定在黑色格子里放 \(x\) 个车的方案数。
那么答案就是:
\[\sum \sum_{i = 0} ^ n (-1) ^ i f(i) (n - i)!
\]
如何计算 \(f\)?注意到我们将每行每列的黑格子连起来,然后就变成了若干条链,不能选择链上相邻的点,直接在链上 \(\texttt{dp{\) 即可。
Code
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
#include <bitset>
#include <vector>
using namespace std;
#ifdef ONLINE_JUDGE
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf;
#endif
int read() {
int p = 0, flg = 1;
char c = getchar();
while (c < '0' || c > '9') {
if (c == '-') flg = -1;
c = getchar();
}
while (c >= '0' && c <= '9') {
p = p * 10 + c - '0';
c = getchar();
}
return p * flg;
}
void write(int x) {
if (x < 0) {
x = -x;
putchar('-');
}
if (x > 9) {
write(x / 10);
}
putchar(x % 10 + '0');
}
bool _stmer;
const int N = 4e3 + 5, M = 8e6 + 5, mod = 924844033;
array <array <array <int, 2>, N>, N> f;
void Mod(int &x) {
if (x >= mod) x -= mod;
if (x < 0) x += mod;
}
namespace Uni {
array <int, M> fa;
int find(int x) {
if (x == fa[x]) return x;
return fa[x] = find(fa[x]);
}
void merge(int x, int y) {
int fx = find(x),
fy = find(y);
if (fx == fy) return;
fa[fy] = fx;
}
void init(int n) {
for (int i = 1; i <= n; i++)
fa[i] = i;
}
} //namespace Uni
int getid(int x, int y) { return (x - 1) * ((int)4e3) + y; }
array <bitset <N>, N> vis;
array <int, M> siz;
bool _edmer;
int main() {
cerr << (&_stmer - &_edmer) / 1024.0 / 1024.0 << "MB\n";
int n = read(), k = read();
Uni::init(8e6);
for (int i = 1; i <= n; i++) {
if (i + k <= n) vis[i][i + k] = 1;
if (i - k > 0) vis[i][i - k] = 1;
}
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
if (!vis[i][j]) continue;
if (i + 2 * k <= n && vis[i + 2 * k][j])
Uni::merge(getid(i, j), getid(i + 2 * k, j));
if (j + 2 * k <= n && vis[i][j + 2 * k])
Uni::merge(getid(i, j), getid(i, j + 2 * k));
}
}
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
if (vis[i][j])
siz[Uni::find(getid(i, j))]++;
vector <int> isl;
for (int i = 1; i <= 8e6; i++)
if (siz[i]) isl.push_back(siz[i]);
#define upd(x, y) (x += y, Mod(x))
f[0][0][0] = 1;
int lst = 0;
for (auto p : isl) {
for (int i = lst + 1; i <= lst + p; i++) {
for (int j = n; ~j; j--) {
upd(f[i][j][0], f[i - 1][j][0]);
upd(f[i][j][0], f[i - 1][j][1]);
if (j)
upd(f[i][j][1], f[i - 1][j - 1][0]);
}
}
lst += p;
for (int i = 1; i <= n; i++)
upd(f[lst][i][0], f[lst][i][1]), f[lst][i][1] = 0;
}
int ans = 0, res = 1;
for (int i = n; ~i; i--) {
ans += ((i & 1) ? -1ll : 1ll) * (f[lst][i][0] + f[lst][i][1]) * res % mod, Mod(ans);
res = 1ll * res * (n - i + 1) % mod;
}
write(ans), puts("");
return 0;
}