CodeForces 715E Complete the Permutations
最小交换次数等于 \(n - \text{环数}\)。所以题目要我们统计把 \(p, q\) 补全成排列,连边 \(p_i \to q_i\),环数 \(= i\) 的方案数。
考虑把边根据 \(p_i, q_i\) 的是否已知状态分成四类:
- \(p \to q\)
- \(p \to 0\)
- \(0 \to q\)
- \(0 \to 0\)
注意若存在 \(p \to 0, 0 \to q\) 且 \(p = q\),我们把它们合并成一个 \(4\) 类边。
对于 \(1\) 类边直接把它缩成一个点,记录一下是否形成环即可。
对于剩下的边,设 \(2\) 类边数量为 \(n_1\),\(3\) 类边数量为 \(n_2\),\(4\) 类边数量为 \(n_3\)。
对于 \(2\) 类边,我们考虑钦定一些边形成环,剩下的边接到 \(4\) 类边去,因为 \(0 \to 0\) 和 \(p \to 0\) 合并还是一条 \(0 \to 0\)。
设 \(f_i\) 为 \(2\) 类边形成 \(i\) 个环的方案数。枚举钦定了 \(j\) 条边形成环,有:
最后乘的下降幂意义是,一条边 \(p_1 \to 0\) 可以和之后的还没处理的边 \(p_2 \to 0\) 合并变成 \(p_1 \to 0\),也可以直接和一条 \(0 \to 0\) 边合并。
需要特判 \(n_3 = 0\),此时 \(f_i = \begin{bmatrix} n_1 \\ i \end{bmatrix}\)。
再设 \(g_i\) 为 \(3\) 类边形成 \(i\) 个环的方案数。计算方法和上面一样,把 \(n_1\) 改成 \(n_2\) 即可。
最后设 \(h_i\) 为 \(4\) 类边形成 \(i\) 个环的方案数,有:
最后乘 \(n_3!\) 是因为我们填排列的时候可以任意排列这些边的顺序。
把 \(f, g, h\) 做加法卷积即可。
时间复杂度 \(O(n^2)\)。
code
// Problem: E. Complete the Permutations
// Contest: Codeforces - Codeforces Round 372 (Div. 1)
// URL: https://codeforces.com/problemset/problem/715/E
// Memory Limit: 256 MB
// Time Limit: 5000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 260;
const ll mod = 998244353;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, a[maxn], b[maxn], c[maxn], fa[maxn], fac[maxn], ifac[maxn], S[maxn][maxn];
ll f[maxn], g[maxn], h[maxn], d[maxn], ans[maxn];
int find(int x) {
return fa[x] == x ? x : fa[x] = find(fa[x]);
}
inline bool merge(int x, int y) {
x = find(x);
y = find(y);
if (x != y) {
fa[x] = y;
return 1;
} else {
return 0;
}
}
inline ll C(ll n, ll m) {
if (n < m || n < 0 || m < 0) {
return 0;
} else {
return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
}
void solve() {
scanf("%lld", &n);
for (int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
fa[i] = i;
}
int cnt = 0;
for (int i = 1; i <= n; ++i) {
scanf("%lld", &b[i]);
if (a[i] && b[i] && !merge(a[i], b[i])) {
++cnt;
}
}
for (int i = 1; i <= n; ++i) {
if (a[i]) {
a[i] = find(a[i]);
}
if (b[i]) {
b[i] = find(b[i]);
}
}
fac[0] = 1;
for (int i = 1; i <= n; ++i) {
fac[i] = fac[i - 1] * i % mod;
}
ifac[n] = qpow(fac[n], mod - 2);
for (int i = n - 1; ~i; --i) {
ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
S[0][0] = 1;
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= i; ++j) {
S[i][j] = (S[i - 1][j] * (i - 1) % mod + S[i - 1][j - 1]) % mod;
}
}
int n1 = 0, n2 = 0, n3 = 0;
for (int i = 1; i <= n; ++i) {
if (a[i] && b[i]) {
continue;
}
if (a[i] && !b[i]) {
++c[a[i]];
++n1;
}
if (!a[i] && b[i]) {
++c[b[i]];
++n2;
}
if (!a[i] && !b[i]) {
++n3;
}
}
for (int i = 1; i <= n; ++i) {
if (c[i] == 2) {
--n1;
--n2;
++n3;
}
}
if (n3) {
for (int i = 0; i <= n1; ++i) {
for (int j = i; j <= n1; ++j) {
f[i] = (f[i] + C(n1, j) * S[j][i] % mod * fac[n1 - j + n3 - 1] % mod * ifac[n3 - 1] % mod) % mod;
}
}
for (int i = 0; i <= n2; ++i) {
for (int j = i; j <= n2; ++j) {
g[i] = (g[i] + C(n2, j) * S[j][i] % mod * fac[n2 - j + n3 - 1] % mod * ifac[n3 - 1] % mod) % mod;
}
}
} else {
for (int i = 0; i <= n1; ++i) {
f[i] = S[n1][i];
}
for (int i = 0; i <= n2; ++i) {
g[i] = S[n2][i];
}
}
for (int i = 0; i <= n3; ++i) {
h[i] = S[n3][i] * fac[n3] % mod;
}
// for (int i = 0; i <= n1; ++i) {
// printf("%lld ", f[i]);
// }
// putchar('\n');
// for (int i = 0; i <= n2; ++i) {
// printf("%lld ", g[i]);
// }
// putchar('\n');
// for (int i = 0; i <= n3; ++i) {
// printf("%lld ", h[i]);
// }
// putchar('\n');
for (int i = 0; i <= n1; ++i) {
for (int j = 0; j <= n2; ++j) {
d[i + j] = (d[i + j] + f[i] * g[j] % mod) % mod;
}
}
for (int i = 0; i <= n1 + n2; ++i) {
for (int j = 0; j <= n3; ++j) {
ans[i + j] = (ans[i + j] + d[i] * h[j] % mod) % mod;
}
}
// for (int i = 0; i <= n1 + n2 + n3; ++i) {
// printf("%lld ", ans[i]);
// }
// putchar('\n');
for (int i = 0; i < n; ++i) {
int k = n - i;
if (k < cnt) {
printf("0 ");
} else {
printf("%lld ", ans[k - cnt]);
}
}
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}