「codeforces - 1608F」MEX counting

link。

首先考虑暴力,枚举规划前缀 \([1, i]\) 和前缀 mex \(x\),则我们需要 \(x\) 个数来填了 \([0, x)\),还剩下 \(i-x\) 个数随便填 \([0, x) \cup (x, n]\),如果直接组合数算可能会出问题,考虑 dp。

定义 \(f[i][x][j]\) 表示规划前缀 \([1, i]\),当前的 mex 为 \(x\),还有 \(j\) 个数当前不对 mex 的取值造成影响(也就是说他们都大于 \(x\),这 \(j\) 个数是指在 \(a\) 数组中的,所以我们不必关心这 \(j\) 个数具体是什么)。转移就分两种情况:

  • \(a[i+1] \neq x\)\((i+1, x, j) \gets (i+1, x, j)+(i, x, j)*(x+j)\)\((i+1, x, j+1) \gets (i, x, j)\),第一个就是考虑当加入的 \(a[i+1]\) 属于那 \(j\) 个数或者属于 \([0, x)\),第二个很简单;
  • \(a[i+1] = x\):设当前 mex 变成了 \(y\),则有 \((i+1, y, j-y+x+1) \gets (i+1, y, j-y+x+1)+(i, x, j) \times \binom{j}{y-x-1} \times (y-x-1)!\),意义明显,注意后面那个是排列数而不是组合数。

然后这个是 \(O(n^2k^2)\) 的,把刷表改成填表后前缀和优化即可。

using modint = modint998244353;
int n, K, a[2100];
modint dp[2][2100][2100], sum[2][2100][2100], fct[2100], ifct[2100];
inline int L(int i) { return max(0, a[i]-K); }
inline int R(int i) { return min(i, a[i]+K); }
signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    fct[0] = 1;
    for (int i=1; i<2100; ++i) {
        fct[i] = fct[i-1]*i;
    }
    ifct[2099] = fct[2099].inv();
    for (int i=2098; i>=0; --i) {
        ifct[i] = ifct[i+1]*(i+1);
    }
    cin >> n >> K;
    for (int i=1; i<=n; ++i) {
        cin >> a[i];
    }
    dp[0][0][0] = sum[0][0][0] = 1;
    for (int i=1,cur=1; i<=n; ++i,cur^=1) {
        for (int j=0; j<=i; ++j) {
            for (int x=L(i); x<=R(i) && x<=j; ++x) {
                dp[cur][x][j] = dp[cur^1][x][j]*j;
                if (j) {
                    dp[cur][x][j] += dp[cur^1][x][j-1];
                }
                if (j && x) {
                    dp[cur][x][j] += sum[cur^1][min(x-1, R(i-1))][j-1]*ifct[j-x];
                }
                sum[cur][x][j] = dp[cur][x][j]*fct[j-x];
                if (x) {
                    sum[cur][x][j] += sum[cur][x-1][j];
                }
            }
        }
        for (int j=0; j<=i-1; ++j) {
            for (int x=L(i-1); x<=R(i-1) && x<=j; ++x) {
                dp[cur^1][x][j] = sum[cur^1][x][j] = 0;
            }
        }
    }
    modint ans = 0;
    for (int i=0; i<=n; ++i) {
        for (int j=L(n); j<=R(n) && j<=i; ++j) {
            ans += dp[n&1][j][i]*fct[n-j]*ifct[n-i];
        }
    }
    cout << ans.val() << "\n";
}
posted @ 2022-07-26 21:28  cirnovsky  阅读(51)  评论(1编辑  收藏  举报