[ABC292G] Count Strictly Increasing Sequences

题意

给定 \(n\) 个由 \(m\) 个字符组成的字符串,你需要将里面所有 \(?\) 替换为 \([0, 9]\),并使得字符串的字典序单调递增。

\(n, m \le 40\)

Sol

相邻限制,想到区间 dp。

由于显然对于同一位 \(i\)\([l, r]\) 在第 \(i\) 为上的数字一定是一段 \(0\),然后一段 \(1\),然后一段 \(2\),...,然后一段 \(9\),考虑枚举 \(k \in [l, r]\),钦定 \([l, k]\) 为当前的数字 \(j\)\([k + 1, r]\) 则由 \(j' \ge j\) 来填。

因此状态定义十分显然了,设 \(f_{i, j, l, r}\) 表示使用第 \(i\) 位,限制当前放的数字 \(\ge j\),满足原字符串序列中 \([l, r]\) 的偏序关系的方案数。

\[f_{i, j, l, r} = \sum_k f_{i + 1, 0, l, k} \times f_{i, j + 1, k + 1, r} \]

以及:

\[f_{i, j, l, r} = f_{i, j + 1, l, r} \]

直接数位 dp 记忆化搜索即可。

时间复杂度 \(O(n ^ 3 m c)\)

Code

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <array>
using namespace std;
#ifdef ONLINE_JUDGE

/* #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++) */
/* char buf[1 << 23], *p1 = buf, *p2 = buf, ubuf[1 << 23], *u = ubuf; */

#endif
int read() {
    int p = 0, flg = 1;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') flg = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        p = p * 10 + c - '0';
        c = getchar();
    }
    return p * flg;
}
void write(int x) {
    if (x < 0) {
        x = -x;
        putchar('-');
    }
    if (x > 9) {
        write(x / 10);
    }
    putchar(x % 10 + '0');
}
bool _stmer;

const int N = 45, mod = 998244353;

array <string, N> mp;
char strbuf[N];

array <array <array <array <int, N>, N>, 11>, N> f;

void Mod(int &x) {
    if (x >= mod) x -= mod;
    if (x < 0) x += mod;
}

int dfs(int x, int y, int l, int r, int m) {
    if (l > r || (l == r && x > m)) return 1;
    if (x > m || y > 9) return 0;
    if (~f[x][y][l][r]) return f[x][y][l][r];
    int ans = dfs(x, y + 1, l, r, m);
    for (int k = l; k <= r; k++) {
        if (mp[k][x] != '?' && mp[k][x] != y + '0') break;
        ans += 1ll * dfs(x + 1, 0, l, k, m) * dfs(x, y + 1, k + 1, r, m) % mod, Mod(ans);
    }
    return f[x][y][l][r] = ans;
}

bool _edmer;
int main() {
    cerr << (&_stmer - &_edmer) / 1024.0 / 1024.0 << "MB\n";
    int n = read(), m = read();

    for (int i = 0; i < N; i++)
        for (int j = 0; j < 10; j++)
            for (int k = 0; k < N; k++)
                f[i][j][k].fill(-1);

    for (int i = 1; i <= n; i++) {
        scanf("%s", strbuf);
        mp[i] = strbuf, mp[i] = " " + mp[i];
    }
    write(dfs(1, 0, 1, n, m)), puts("");
    return 0;
}
posted @ 2024-10-05 10:12  cxqghzj  阅读(5)  评论(0编辑  收藏  举报