Codeforces 1608F. MEX Counting (3200)
题目描述
给定一个长度为 \(n\) 的序列 \(b\),需要计算满足下列条件的序列 \(a\) 的个数,答案对 \(998244353\) 取模。
- 序列 \(a\) 的长度为 \(n\);
- \(\forall i\in[1,n],0\le a_i\le n\);
- \(\forall i\in[1,n],|mex(a_1,a_2,\cdots,a_i)-b_i|\le k\)。
\(1\le n\le 2000,1\le k\le 50,b_i\in[-k,n+k]\)。
考虑 \(\text{dp}\)。
设 \(f_{i,j,k}\) 表示填完前 \(i\) 个数,\(\text{mex}\) 为 \(j\),有 \(k\) 种 \(>j\) 的值的方案数。
转移考虑两种情况。
- \(\text{mex}\) 不改变。此时 \(a_{i+1}\neq j\),那么考虑是否加入一个 \(>j\) 的新值。
- 加入新值。\(f_{i,j,k}\rightarrow f_{i+1,j,k+1}\)。
- 不加入新值。\((j+k)\cdot f_{i,j,k}\rightarrow f_{i+1,j,k}\)。
- \(\text{mex}\) 改变,此时 \(a_{i+1}=j\)。考虑枚举转移到的新 \(\text{mex}\) 为 \(x\)。则必定满足在原来 \(k\) 种 \(>j\) 的值中必定出现过 \(j+1\sim x-1\),那么这部分排列数即是 \(\frac{k!}{(k-(x-j-1))!}\)。所以 \(\forall x,|x-b_{i+1}|\le k,\frac{k!}{(k-(x-j-1))!}\cdot f_{i,j,k}\rightarrow f_{i+1,x,k-(x-j-1)}\)。
时间复杂度是 \(O(n^2k^2)\),考虑优化 \(\text{mex}\) 改变的地方。可以发现 \(f_{i+1,y,k-(x-j-1)}\) 往后转移时会乘 \((k-(x-j-1))!\),与当前转移过去抵消。
于是可以设 \(g_{i,j,k}=k!f_{i,j,k}\)。那么转移变成
- \((j+k)\cdot g_{i,j,k}\rightarrow g_{i+1,j,k}\)。
- \((k+1)\cdot g_{i,j,k}\rightarrow g_{i+1,j,k+1}\)。
- \(\forall x,|x-b_{i+1}|\le k,g_{i,j,k}\rightarrow g_{i+1,x,k-(x-j-1)}\)。
将第 \(3\) 个转移改写为 \(g_{i,x,s}=g_{i,x-1,s+1}+g_{i-1,x-1,s}\)。
答案即为 \(\sum\limits_{x,s}\binom{n-x}{s}\cdot g_{n,x,s}\)。
滚动数组优化。时间复杂度 \(O(n^2k)\),空间复杂度 \(O(nk)\) 或 \(O(n^2)\)。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 2005, mod = 998244353;
int n, m, ans, a[N], l[N], r[N], f[2][N][N]; ll fac[N], inv[N];
inline ll C(int n, int m) { return fac[n] * inv[m] % mod * inv[n - m] % mod; }
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++ i) scanf("%d", a + i), l[i] = max(a[i] - m, 0), r[i] = min(a[i] + m, i);
fac[0] = inv[0] = inv[1] = 1;
for (int i = 1; i < N; ++ i) fac[i] = fac[i - 1] * i % mod;
for (int i = 2; i < N; ++ i) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
for (int i = 1; i < N; ++ i) inv[i] = inv[i] * inv[i - 1] % mod;
f[0][0][0] = 1;
for (int i = 1; i <= n; ++ i) {
int p = i & 1, q = p ^ 1;
for (int k = 0; k <= i - l[i]; ++ k)
for (int j = max(l[i] - r[i - 1] - 1, 0); j < l[i] - l[i - 1]; ++ j)
(f[p][l[i]][k] += f[p ^ 1][l[i] - j - 1][j + k]) %= mod;
for (int j = l[i] + 1; j <= r[i]; ++ j)
for (int k = 0; k <= i - j; ++ k)
f[p][j][k] = (f[p][j - 1][k + 1] + f[p ^ 1][j - 1][k]) % mod;
for (int j = l[i - 1]; j <= r[i - 1]; ++ j)
for (int k = 0; k < i - j; ++ k)
if (l[i] <= j && j <= r[i])
f[p][j][k] = (f[p][j][k] + (ll)(j + k) * f[q][j][k]) % mod,
f[p][j][k + 1] = (f[p][j][k + 1] + (ll)(k + 1) * f[q][j][k]) % mod;
for (int j = l[i - 1]; j <= r[i - 1]; ++ j)
for (int k = 0; k < i - j; ++ k) f[q][j][k] = 0;
}
for (int i = 0; i <= n; ++ i)
for (int j = 0; j <= n - i; ++ j)
(ans += C(n - i, j) * f[n & 1][i][j] % mod) %= mod;
return printf("%d\n", ans), 0;
}