【学习笔记】DP 套 DP

DP 套 DP 的题目一般大致要求为求满足某个要求的元素数量。而对于其要求的判定需要 DP 来解决。因此我们将判定 DP 的结果作为计数 DP 的状态来进行计数。

[TJOI2018] 游园会

首先考虑如何求出两个字符串 \(S, T\) 的 LCS,设 \(f_{i, j}\) 表示 \(S\left[1, i\right]\)\(T\left[1, j\right]\) 的 LCS 长度。发现其有如下转移:

\[f_{i, j} = \begin{cases} f_{i - 1, j - 1} + 1 && S_i = T_j \\ \max\left\{f_{i, j - 1}, f_{i- 1, j}\right\} && S_i \neq T_j \end{cases}\]

我们考虑对上述判定 DP 的过程进行计数,具体的,将上述判定 DP 结果相同的前缀作为相同的子问题进行合并后计数以优化复杂度,但是若直接对于某个 \(i\) 将所有的 \(f_{i, j}\) 值均压入状态那么状态数为 \(\mathcal{O}(N \times K^K)\),无法接受,需要进一步发掘性质。

我们可以发现,对于某个 \(i\),我们有 \(f_{i, j} - f_{i, j - 1} \le 1\),即 \(f\) 数组的差分值值域为 \(\left[0, 1\right]\),若将查分数组压入状态那么复杂度是 \(\mathcal{O}(N \times 2^K)\) 级别的,可以接受。

进而我们可以预处理出判定 DP 的转移边,即对于所有可能的 \(f_{i, *}\),枚举 \(S_{i + 1}\) 的值并计算得到的 \(f_{i + 1, *}\)。接下来进行计数,设 \(g_{i, S}\) 表示长度为 \(i\) 的,使得 \(f_{i}\) 差分值在 \(S\) 处为 \(1\) 的字符串数量,转移时枚举所有合法的下一个字符并预处理的转移边进行转移即可。

考虑如何处理其中不能出现 \(\tt{NOI}\) 子串的限制,在我们的计数 DP 中额外维护一维代表其目前与 \(\tt{NOI}\) 匹配的长度即可,注意这里的匹配要求必须选择最后一个字符,例如字符串 \(\tt{NONONONONONONONONON}\) 的匹配长度为 \(1\)

至此我们便可以通过此题,复杂度为 \(\mathcal{O}(N 2^K)\)

Code
#include <bits/stdc++.h>

typedef int valueType;
typedef std::vector<valueType> ValueVector;
typedef std::vector<ValueVector> ValueMatrix;

namespace MODINT_WITH_FIXED_MOD {
    constexpr valueType MOD = 1e9 + 7;

    template<typename T1, typename T2>
    void Inc(T1 &a, T2 b) {
        a = a + b;

        if (a >= MOD)
            a -= MOD;
    }

    template<typename T1, typename T2>
    void Dec(T1 &a, T2 b) {
        a = a - b;

        if (a < 0)
            a += MOD;
    }

    template<typename T1, typename T2>
    T1 sum(T1 a, T2 b) {
        return a + b >= MOD ? a + b - MOD : a + b;
    }

    template<typename T1, typename T2>
    T1 sub(T1 a, T2 b) {
        return a - b < 0 ? a - b + MOD : a - b;
    }

    template<typename T1, typename T2>
    T1 mul(T1 a, T2 b) {
        return (long long) a * b % MOD;
    }

    template<typename T1, typename T2>
    void Mul(T1 &a, T2 b) {
        a = (long long) a * b % MOD;
    }

    template<typename T1, typename T2>
    T1 pow(T1 a, T2 b) {
        T1 result = 1;

        while (b > 0) {
            if (b & 1)
                Mul(result, a);

            Mul(a, a);
            b = b >> 1;
        }

        return result;
    }
} // namespace MODINT_WITH_FIXED_MOD

using namespace MODINT_WITH_FIXED_MOD;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    valueType N, K;

    std::cin >> N >> K;

    std::string S_;

    std::cin >> S_;

    ValueVector Table(1 << 8, -1);

    Table['N'] = 0;
    Table['O'] = 1;
    Table['I'] = 2;

    ValueVector S(S_.begin(), S_.end());

    for (auto &x : S)
        x = Table[x];

    ValueMatrix Transfer(1 << K, ValueVector(3, -1));

    for (valueType s = 0; s < (1 << K); ++s) {
        ValueVector prev(K + 1, 0);

        for (valueType i = 1; i <= K; ++i)
            prev[i] = (s >> (i - 1)) & 1;

        std::partial_sum(prev.begin(), prev.end(), prev.begin());

        for (valueType c = 0; c < 3; ++c) {
            ValueVector next(K + 1, 0);

            for (valueType i = 1; i <= K; ++i) {
                next[i] = std::max(prev[i], next[i - 1]);

                if (S[i - 1] == c)
                    next[i] = std::max(next[i], prev[i - 1] + 1);
            }

            std::adjacent_difference(next.begin(), next.end(), next.begin());

            valueType t = 0;

            for (valueType i = 1; i <= K; ++i)
                t |= next[i] << (i - 1);

            Transfer[s][c] = t;
        }
    }

    std::array<ValueMatrix, 2> F;

    F[0].resize(1 << K, ValueVector(3, 0));
    F[1].resize(1 << K, ValueVector(3, 0));

    F[0][0][0] = 1;

    for (valueType i = 1; i <= N; ++i) {
        valueType const now = i & 1, prev = now ^ 1;

        for (auto &v : F[now])
            std::fill(v.begin(), v.end(), 0);

        for (valueType s = 0; s < (1 << K); ++s) {
            for (valueType c = 0; c < 3; ++c) {
                valueType const t = Transfer[s][c];

                if (c == 0) { // N
                    Inc(F[now][t][1], F[prev][s][0]);
                    Inc(F[now][t][1], F[prev][s][1]);
                    Inc(F[now][t][1], F[prev][s][2]);
                }

                if (c == 1) { // O
                    Inc(F[now][t][0], F[prev][s][0]);
                    Inc(F[now][t][2], F[prev][s][1]);
                    Inc(F[now][t][0], F[prev][s][2]);
                }

                if (c == 2) { // I
                    Inc(F[now][t][0], F[prev][s][0]);
                    Inc(F[now][t][0], F[prev][s][1]);
                }
            }
        }
    }

    ValueVector Ans(K + 1, 0);

    for (valueType s = 0; s < (1 << K); ++s) {
        valueType const popcount = __builtin_popcountll(s);

        for (valueType c = 0; c < 3; ++c)
            Inc(Ans[popcount], F[N & 1][s][c]);
    }

    for (valueType i = 0; i <= K; ++i)
        std::cout << Ans[i] << std::endl;

    return 0;
}

CF979E Kuro and Topological Parity

我们首先考虑在确定节点颜色的情况下如何计数,设 \(f_{u, 0 / 1, 0 / 1}\) 表示考虑标号不大于 \(u\) 的所有点,以 \(u\) 结尾的合法路径条数模 \(2\) 后的值为 \(0 / 1\),且好的合法路径条数总数模 \(2\) 后的值为 \(0 / 1\) 的方案数。我们对于某个节点 \(u\),若以 \(u\) 结尾的合法路径条数模 \(2\) 后的值为 \(1\),那么我们称之为奇点,反之为偶点。那么对于上述 DP 时在转移时枚举异色奇点有多少个与之相连即可。

不难发现影响路径总数奇偶性的是奇点个数的奇偶性,而决定一个点奇偶性的是与之相连的异色奇数点个数,这启示我们将黑色奇点和白色奇点的个数作为状态进行奇数,设 \(f_{i, j, k}\) 表示考虑标号不大于 \(i\) 的所有点,黑色奇点的个数为 \(j\),白色奇点的个数为 \(k\) 的方案数。每次转移时考虑新增节点的颜色和其奇偶性即可。

现在还剩一个问题,若我们希望新增点数为奇点,那么我们便要选择偶数个异色奇点(该新增点自身为一条路径),设有 \(m\) 个异色奇点,那么其方案数为:

\[\sum\limits_{i \,\text{is even}}\dbinom{m}{i} \]

我们可以证明其为 \(2^{m - 1}\),具体的,考虑选偶数个和选奇数个的方案数之差,我们有:

\[\begin{aligned} &\sum\limits_{i \,\text{is even}}\dbinom{m}{i} - \sum\limits_{i \,\text{is odd}}\dbinom{m}{i}\\ =&\sum\limits_{i = 0}^{m}\left(-1\right)^i\dbinom{m}{i} \\ =&\sum\limits_{i = 0}^{m}1^{m - i}\left(-1\right)^i\dbinom{m}{i} \\ =&\left(1 - 1\right)^m\\ =&\left[m = 0\right] \end{aligned}\]

因此直接进行转移即可,复杂度为 \(\mathcal{O}(n^3)\)

Code
#include <bits/stdc++.h>

typedef int valueType;
typedef std::vector<valueType> ValueVector;
typedef std::vector<ValueVector> ValueMatrix;
typedef std::vector<ValueMatrix> ValueCube;

namespace MODINT_WITH_FIXED_MOD {
    constexpr valueType MOD = 1e9 + 7;

    template<typename T1, typename T2>
    void Inc(T1 &a, T2 b) {
        a = a + b;

        if (a >= MOD)
            a -= MOD;
    }

    template<typename T1, typename T2>
    void Dec(T1 &a, T2 b) {
        a = a - b;

        if (a < 0)
            a += MOD;
    }

    template<typename T1, typename T2>
    T1 sum(T1 a, T2 b) {
        return a + b >= MOD ? a + b - MOD : a + b;
    }

    template<typename T1, typename T2>
    T1 sub(T1 a, T2 b) {
        return a - b < 0 ? a - b + MOD : a - b;
    }

    template<typename T1, typename T2>
    T1 mul(T1 a, T2 b) {
        return (long long) a * b % MOD;
    }

    template<typename T1, typename T2>
    void Mul(T1 &a, T2 b) {
        a = (long long) a * b % MOD;
    }

    template<typename T1, typename T2>
    T1 pow(T1 a, T2 b) {
        T1 result = 1;

        while (b > 0) {
            if (b & 1)
                Mul(result, a);

            Mul(a, a);
            b = b >> 1;
        }

        return result;
    }
} // namespace MODINT_WITH_FIXED_MOD

using namespace MODINT_WITH_FIXED_MOD;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    valueType N, P;

    std::cin >> N >> P;

    ValueCube F;

    F.resize(N + 1, ValueMatrix(N + 1, ValueVector(N + 1, 0)));

    ValueVector C(N + 1, 0);

    for (valueType i = 1; i <= N; ++i)
        std::cin >> C[i];

    if (N == 1) {
        if (P == 0)
            std::cout << 0 << std::endl;
        else if (C[1] == -1)
            std::cout << 2 << std::endl;
        else
            std::cout << 1 << std::endl;

        return 0;
    }

    // f_{i, j, k} i 个点, j 个奇黑, k 个奇白

    if (C[1] != 1)
        F[1][1][0] = 1;

    if (C[1] != 0)
        F[1][0][1] = 1;

    for (valueType i = 2; i <= N; ++i) {
        for (valueType j = 0; j < i; ++j) {
            for (valueType k = 0; j + k < i; ++k) { // 上一位的状态
                if (C[i] != 1) {                    // 0 : 黑
                    if (k > 0)
                        Inc(F[i][j + 1][k], mul(F[i - 1][j][k], ((1ll << (i - 2)) % MOD)));
                    else
                        Inc(F[i][j + 1][k], mul(F[i - 1][j][k], ((1ll << (i - 1)) % MOD)));

                    if (k > 0)
                        Inc(F[i][j][k], mul(F[i - 1][j][k], ((1ll << (i - 2)) % MOD)));
                }

                if (C[i] != 0) { // 1 : 白
                    if (j > 0)
                        Inc(F[i][j][k + 1], mul(F[i - 1][j][k], ((1ll << (i - 2)) % MOD)));
                    else
                        Inc(F[i][j][k + 1], mul(F[i - 1][j][k], ((1ll << (i - 1)) % MOD)));

                    if (j > 0)
                        Inc(F[i][j][k], mul(F[i - 1][j][k], ((1ll << (i - 2)) % MOD)));
                }
            }
        }
    }

    valueType ans = 0;

    for (valueType j = 0; j <= N; ++j)
        for (valueType k = 0; j + k <= N; ++k)
            if (((j + k) & 1) == P)
                Inc(ans, F[N][j][k]);

    std::cout << ans << std::endl;

    return 0;
}

[SDOI/SXOI2022] 小 N 的独立集

考虑在给定点权的情况下如何求最大独立集,设 \(f_{u, 0 / 1}\) 表示考虑 \(u\) 子树内的点,选 / 不选 \(u\) 的情况下的最大独立集权值。这样的话 DP 值为共有 \(\mathcal{O}(\left(nk\right)^2)\) 级别的。考虑优化,发现我们实际上只关心 \(\max\left\{f_{u, 0}, f_{u, 1}\right\}\)\(f_{u, 0}\) 的值,同时可以发现我们有 \(0 \le \max\left\{f_{u, 0}, f_{u, 1}\right\} - f_{u, 0} \le k\),这样我们的状态数就变为了 \(\mathcal{O}(nk^2)\) 级别。考虑上述两个值的含义,考虑设 \(f_{u, 0 / 1}\) 考虑 \(u\) 子树内的点,是否钦定不选择 \(u\) 的情况下的最大独立集权值。我们有转移:

\[\begin{aligned} f_{u, 0} &\leftarrow f_{v, 1} \\ f_{u, 1} &\leftarrow f_{v, 0} + k \\ \end{aligned}\]

下面考虑如何计数,设 \(g_{u, s, t}\) 表示考虑 \(u\) 子树内的点,满足 \(f_{u, 0} = s\)\(f_{u, 1} = t\) 的方案数,树上背包转移即可。复杂度 \(\mathcal{O}(n^2k^4)\)

Code
#include <bits/stdc++.h>

typedef int valueType;
typedef std::vector<valueType> ValueVector;
typedef std::vector<ValueVector> ValueMatrix;
typedef std::vector<ValueMatrix> ValueCube;

namespace MODINT_WITH_FIXED_MOD {
    constexpr valueType MOD = 1e9 + 7;

    template<typename T1, typename T2>
    void Inc(T1 &a, T2 b) {
        a = a + b;

        if (a >= MOD)
            a -= MOD;
    }

    template<typename T1, typename T2>
    void Dec(T1 &a, T2 b) {
        a = a - b;

        if (a < 0)
            a += MOD;
    }

    template<typename T1, typename T2>
    T1 sum(T1 a, T2 b) {
        return a + b >= MOD ? a + b - MOD : a + b;
    }

    template<typename T1, typename T2>
    T1 sub(T1 a, T2 b) {
        return a - b < 0 ? a - b + MOD : a - b;
    }

    template<typename T1, typename T2>
    T1 mul(T1 a, T2 b) {
        return (long long) a * b % MOD;
    }

    template<typename T1, typename T2>
    void Mul(T1 &a, T2 b) {
        a = (long long) a * b % MOD;
    }

    template<typename T1, typename T2>
    T1 pow(T1 a, T2 b) {
        T1 result = 1;

        while (b > 0) {
            if (b & 1)
                Mul(result, a);

            Mul(a, a);
            b = b >> 1;
        }

        return result;
    }
} // namespace MODINT_WITH_FIXED_MOD

using namespace MODINT_WITH_FIXED_MOD;

valueType N, K;
ValueMatrix G;
ValueCube F;
ValueVector Size;

void dfs(valueType x, valueType from) {
    Size[x] = K;

    for (valueType k = 1; k <= K; ++k)
        Inc(F[x][0][k], 1);

    for (auto const &to : G[x]) {
        if (to == from)
            continue;

        dfs(to, x);

        ValueMatrix Next(Size[x] + Size[to] + 1, ValueVector(K + 1, 0));

        for (valueType t = 0; t <= Size[x]; ++t) {
            for (valueType d_t = 0; d_t <= K && d_t + t <= Size[x]; ++d_t) {
                if (F[x][t][d_t] == 0)
                    continue;

                for (valueType s = 0; s <= Size[to]; ++s) {
                    for (valueType d_s = 0; d_s <= d_t && d_s + s <= Size[to]; ++d_s) {
                        Inc(Next[s + d_s + t][d_t - d_s], mul(F[x][t][d_t], F[to][s][d_s]));
                    }

                    for (valueType d_s = d_t + 1; d_s <= K && d_s + s <= Size[to]; ++d_s) {
                        Inc(Next[s + d_s + t][0], mul(F[x][t][d_t], F[to][s][d_s]));
                    }
                }
            }
        }

        F[x].swap(Next);

        Size[x] += Size[to];
    }
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    std::cin >> N >> K;

    G.resize(N + 1);

    for (valueType i = 1; i < N; ++i) {
        valueType u, v;

        std::cin >> u >> v;

        G[u].push_back(v);
        G[v].push_back(u); // return 0;
    }

    F.resize(N + 1, ValueMatrix(N * K + 1, ValueVector(K + 1, 0)));
    Size.resize(N + 1, 0);

    dfs(1, 0);

    ValueVector Ans(N * K + 1, 0);

    for (valueType s = 1; s <= N * K; ++s)
        for (valueType d = 0; d <= K && s + d <= N * K; ++d)
            Inc(Ans[s + d], F[1][s][d]);

    for (valueType i = 1; i <= N * K; ++i)
        std::cout << Ans[i] << '\n';

    std::cout << std::flush;

    std::exit(0);
}

CF1784E Infinite Game

首先考虑若确定 \(s\) 后如何计算答案。

发现比分只有 \(\left(0 : 0\right), \left(0 : 1\right), \left(1 : 0\right), \left(1 : 1\right)\) 四种状态。我们不妨对于每种状态以其作为初始状态来按 \(s\) 进行一轮游戏并得到终点状态和在这一轮游戏中 Alice 的得分与 Bob 得分的差以作为边权,按其建边可以得到一个内向基环树森林。由于我们求的是比分之比的极限,因此我们只需要考虑从 \(\left(0 : 0\right)\) 出发可以到达的环上的边权即可。

考虑对这个内向基环树森林进行计数,考虑到只有四条边和四个节点,因此我们考虑将其压入状态,设 \(f\left({i, \left\{u_0, u_1, u_2, u_3\right\}, \left\{w_0, w_1, w_2, w_3\right\}}\right)\) 表示考虑 \(S\left[1, i\right]\) 且四个状态指向的状态依次为 \(u_0, u_1, u_2, u_3\) 且边权依次为 \(w_0, w_1, w_2, w_3\) 的方案数。不难发现这个 DP 的时间复杂度为 \(\mathcal{O}(n^4)\),无法接受。考虑如何优化。

发现影响复杂度的主要是对边权的统计,考虑压缩这些状态。发现我们实际上要求的是环上的边权和,又考虑到节点个数很少,因此我们可以枚举环上的节点然后只记录环上的边权和。设 \(f\left(i, \left\{u_0, u_1, u_2, u_3\right\}, s\right)\) 表示考虑 \(S\left[1, i\right]\) 且四个状态指向的状态依次为 \(u_0, u_1, u_2, u_3\) 且环上边权之和为 \(s\) 的方案数,这样通过预处理转移边和边权即可实现快速转移。

复杂度为 \(\mathcal{O}(n^2)\),常数大约为 \(2^4 \times 4^4\)

Code
#include <bits/stdc++.h>

typedef int valueType;
typedef std::vector<valueType> ValueVector;
typedef std::vector<ValueVector> ValueMatrix;

namespace MODINT_WITH_FIXED_MOD {
    constexpr valueType MOD = 998244353;

    template<typename T1, typename T2>
    void Inc(T1 &a, T2 b) {
        a = a + b;

        if (a >= MOD)
            a -= MOD;
    }

    template<typename T1, typename T2>
    void Dec(T1 &a, T2 b) {
        a = a - b;

        if (a < 0)
            a += MOD;
    }

    template<typename T1, typename T2>
    T1 sum(T1 a, T2 b) {
        return a + b >= MOD ? a + b - MOD : a + b;
    }

    template<typename T1, typename T2>
    T1 sub(T1 a, T2 b) {
        return a - b < 0 ? a - b + MOD : a - b;
    }

    template<typename T1, typename T2>
    T1 mul(T1 a, T2 b) {
        return (long long) a * b % MOD;
    }

    template<typename T1, typename T2>
    void Mul(T1 &a, T2 b) {
        a = (long long) a * b % MOD;
    }

    template<typename T1, typename T2>
    T1 pow(T1 a, T2 b) {
        T1 result = 1;

        while (b > 0) {
            if (b & 1)
                Mul(result, a);

            Mul(a, a);
            b = b >> 1;
        }

        return result;
    }
} // namespace MODINT_WITH_FIXED_MOD

using namespace MODINT_WITH_FIXED_MOD;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    std::string S;

    std::cin >> S;

    valueType const N = S.length();

    ValueVector Ans(3, 0);

    for (valueType circle = 1; circle < (1 << 4); ++circle) {
        ValueMatrix NextWeight(1 << 8, ValueVector(2, 0)), NextState(1 << 8, ValueVector(2, 0));

        for (valueType state = 0; state < (1 << 8); ++state) {
            for (valueType c = 0; c < 2; ++c) {
                for (valueType i = 0; i < 4; ++i) {
                    valueType to = (state >> (2 * i)) & 3;

                    if (c == 0) {
                        if (to & 1) {              //  已有分数,获胜一局
                            if ((circle >> i) & 1) // 在环中
                                ++NextWeight[state][c];

                            to = 0;
                        } else { // 获得一分
                            to |= 1;
                        }
                    } else { // c == 1
                        if (to & 2) {
                            if ((circle >> i) & 1)
                                --NextWeight[state][c];

                            to = 0;
                        } else {
                            to |= 2;
                        }
                    }

                    NextState[state][c] |= to << (2 * i);
                }
            }
        }

        valueType const Bound = (N + 1) / 2 * 2 * __builtin_popcount(circle);

        ValueMatrix F((1 << 8), ValueVector(2 * Bound + 1));

        F[3 << 6 | 2 << 4 | 1 << 2 | 0 << 0][Bound] = 1;

        for (valueType i = 0; i < N; ++i) {
            ValueMatrix Next((1 << 8), ValueVector(2 * Bound + 1, 0));

            for (valueType state = 0; state < (1 << 8); ++state) {
                for (valueType sum = 0; sum <= 2 * Bound; ++sum) {
                    if (F[state][sum] == 0)
                        continue;

                    for (valueType c = 0; c < 2; ++c) {
                        if (c == 0 && S[i] == 'b')
                            continue;

                        if (c == 1 && S[i] == 'a')
                            continue;

                        Inc(Next[NextState[state][c]][sum + NextWeight[state][c]], F[state][sum]);
                    }
                }
            }

            F.swap(Next);
        }

        for (valueType state = 0; state < (1 << 8); ++state) { // 检查是否在环上
            ValueVector To(4);

            for (valueType i = 0; i < 4; ++i)
                To[i] = (state >> (2 * i)) & 3;

            ValueVector Path({0});

            while (true) {
                valueType const x = Path.back();

                if (std::find(Path.begin(), Path.end(), To[x]) != Path.end()) {
                    Path.erase(Path.begin(), std::find(Path.begin(), Path.end(), To[x]));

                    break;
                } else {
                    Path.push_back(To[x]);
                }
            }

            valueType realCircle = 0;

            for (auto const &x : Path)
                realCircle |= 1 << x;

            if (realCircle != circle)
                continue;

            for (valueType sum = 0; sum <= 2 * Bound; ++sum) {
                if (sum == Bound) {
                    Inc(Ans[1], F[state][sum]);
                } else if (sum > Bound) {
                    Inc(Ans[0], F[state][sum]);
                } else {
                    Inc(Ans[2], F[state][sum]);
                }
            }
        }
    }

    std::cout << Ans[0] << std::endl;

    std::cout << Ans[1] << std::endl;

    std::cout << Ans[2] << std::endl;

    return 0;
}

[ZJOI2019] 麻将

首先给出一些定义:

  • 顺子:三张大小相邻的牌,例如 \(i, i + 1, i + 2\),其中 \(1, \le i \le n - 2\)

  • 刻子:三张大小相同的牌,例如 \(i, i, i\),其中 \(1 \le i \le n\)

  • 面子:顺子和刻子的统称。

  • 对子:两张大小相同的牌,例如 \(i, i\),其中 \(1 \le i \le n\)


我们首先考虑如何判断是否胡牌,对于第二种胡牌方法的判断是简单的,主要考虑第一种。

可以发现,若我们最终的组合方式中存在三个顺子,那么我们可以将其转化为三个刻字。因此我们只需要考虑顺子个数小于三个的情况,考虑设 \(f_{i, j, k}\) 表示考虑大小小于 \(i\) 的牌,其中存在 \(j\) 个顺子 \(\left(i - 2, i - 1, i\right)\)\(k\) 个顺子 \(\left(i - 1, i, i + 1\right)\) 的情况下可以得到的最多面子数。转移时枚举顺子 \(\left(i, i + 1, i + 2\right)\) 的个数 \(l\),若 \(j + k + l\) 超过了第 \(i\) 种牌的数量则非法,否则将剩余的第 \(i\) 种牌全部转化为刻子。进而我们得到了一个判定 DP。

对于牌组中是否存在对子的判断是简单的,在 \(f\) 中增加一维表示是否存在对子即可。对于第二种胡牌方式,可以在 \(f\) 中再增加一维表示可以得到的最大的对子个数(牌组与第一组胡牌方式独立,即一张牌可以同时在该维和上述转移中产生贡献)。

可以发现上述 DP 的状态数不多,考虑建出自动机后进行计数。暴力进行转移后得到的自动机大约是 \(3956\) 个节点,因此我们可以进行 DP。

考虑到胡牌轮数的期望难以计算,可以转化为求截至第 \(i\) 轮尚未胡牌的概率,进而可以计数。

\(g_{i, j, k}\) 表示考虑标号不大于 \(i\) 的牌,选了 \(j\) 个且当前自动机状态为 \(k\) 的方案数。转移时枚举第 \(i + 1\) 种牌拿几个即可。

复杂度 \(\mathcal{O}(n^2S)\),其中 \(S = 3956\)

Code
#include <bits/stdc++.h>

typedef int valueType;
typedef std::vector<valueType> ValueVector;
typedef std::vector<ValueVector> ValueMatrix;
typedef std::vector<ValueMatrix> ValueCube;
typedef std::vector<bool> bitset;

namespace MODINT_WITH_FIXED_MOD {
    constexpr valueType MOD = 998244353;

    template<typename T1, typename T2>
    void Inc(T1 &a, T2 b) {
        a = a + b;

        if (a >= MOD)
            a -= MOD;
    }

    template<typename T1, typename T2>
    void Dec(T1 &a, T2 b) {
        a = a - b;

        if (a < 0)
            a += MOD;
    }

    template<typename T1, typename T2>
    T1 sum(T1 a, T2 b) {
        return a + b >= MOD ? a + b - MOD : a + b;
    }

    template<typename T1, typename T2>
    T1 sub(T1 a, T2 b) {
        return a - b < 0 ? a - b + MOD : a - b;
    }

    template<typename T1, typename T2>
    T1 mul(T1 a, T2 b) {
        return (long long) a * b % MOD;
    }

    template<typename T1, typename T2>
    void Mul(T1 &a, T2 b) {
        a = (long long) a * b % MOD;
    }

    template<typename T1, typename T2>
    T1 pow(T1 a, T2 b) {
        T1 result = 1;

        while (b > 0) {
            if (b & 1)
                Mul(result, a);

            Mul(a, a);
            b = b >> 1;
        }

        return result;
    }
} // namespace MODINT_WITH_FIXED_MOD

using namespace MODINT_WITH_FIXED_MOD;

class BinomialCoefficient {
private:
    valueType N;
    ValueVector Fact_, InvFact_;

public:
    BinomialCoefficient() = default;

    BinomialCoefficient(valueType n) : N(n), Fact_(N + 1, 1), InvFact_(N + 1, 1) {
        for (valueType i = 1; i <= N; ++i)
            Fact_[i] = mul(Fact_[i - 1], i);

        InvFact_[N] = pow(Fact_[N], MOD - 2);

        for (valueType i = N - 1; i >= 0; --i)
            InvFact_[i] = mul(InvFact_[i + 1], i + 1);
    }

    valueType operator()(valueType n, valueType m) const {
        if (n < 0 || m < 0 || n < m)
            return 0;

        if (m > N)
            throw std::out_of_range("BinomialCoefficient::operator() : m > N");

        if (n <= N)
            return mul(Fact_[n], mul(InvFact_[m], InvFact_[n - m]));

        valueType result = 1;

        for (valueType i = 0; i < m; ++i)
            Mul(result, n - i);

        Mul(result, InvFact_[m]);

        return result;
    }

    valueType Fact(valueType n) const {
        return Fact_[n];
    }
};

class Mahjong {
private:
    class State {
    protected:
        ValueMatrix F;

    public:
        State() : F(3, ValueVector(3, -1)) {
            // F = ValueMatrix(3, ValueVector(3, -1));
        }

        void SetHu() {
            // F = ValueMatrix(3, ValueVector(3, -1));
            std::abort();
        }

        void SetFirst() {
            // F = ValueMatrix(3, ValueVector(3, 0));
            // F = ValueMatrix(3, ValueVector(3, -1));
            for (auto &v : F)
                std::fill(v.begin(), v.end(), -1);

            F[0][0] = 0;
        }

        void SetSecond() {
            // F = ValueMatrix(3, ValueVector(3, -1));

            for (auto &v : F)
                std::fill(v.begin(), v.end(), -1);
        }

        bool CheckHu() const {
            for (valueType i = 0; i < 3; ++i)
                for (valueType j = 0; j < 3; ++j)
                    if (F[i][j] >= 4)
                        return true;

            return false;
        }

    public:
        friend bool operator<(State const &a, State const &b) {
            //     for (valueType i = 0; i < 3; ++i)
            //         for (valueType j = 0; j < 3; ++j)
            //             if (a.F[i][j] != b.F[i][j])
            //                 return a.F[i][j] < b.F[i][j];

            //     return false;
            return a.F < b.F;
        }

        friend bool operator==(State const &a, State const &b) {
            return a.F == b.F;
        }

        friend State operator+(State const &S, valueType count) {
            State T;

            for (valueType i = 0; i < 3; ++i) {
                for (valueType j = 0; j < 3; ++j) {
                    if (S.F[i][j] == -1)
                        continue;

                    for (valueType k = 0; k < 3 && i + j + k <= count; ++k)
                        T.F[j][k] = std::max(T.F[j][k], std::min<valueType>(4, S.F[i][j] + i + (count - i - j - k) / 3));
                }
            }

            return T;
        }

        friend State Merge(State const &A, State const &B) {
            State result;

            for (valueType i = 0; i < 3; ++i)
                for (valueType j = 0; j < 3; ++j)
                    result.F[i][j] = std::max(A.F[i][j], B.F[i][j]);

            return result;
        }
    };

private:
    std::pair<State, State> state;
    valueType pairCount;

public:
    Mahjong() {
        SetInit();
    };

    void SetHu() {
        state.first.SetHu();
        state.second.SetHu();

        pairCount = -1;
    }

    void SetInit() {
        state.first.SetFirst();
        state.second.SetSecond();

        pairCount = 0;
        // pairCount = -1;
    }

    bool CheckHu() {
        if (pairCount >= 7 || state.second.CheckHu()) {
            // SetHu();

            return true;
        } else {
            return false;
        }
    }

public:
    friend bool operator<(Mahjong const &a, Mahjong const &b) {
        if (a.pairCount == b.pairCount)
            return a.state < b.state;

        return a.pairCount < b.pairCount;
    }

    friend bool operator==(Mahjong const &a, Mahjong const &b) {
        return a.pairCount == b.pairCount && a.state == b.state;
    }

    friend Mahjong operator+(Mahjong const &S, valueType count) {
        Mahjong T;

        T.SetInit();

        T.pairCount = std::min<valueType>(7, S.pairCount + (count >= 2 ? 1 : 0));

        if (count >= 2) {
            T.state.second = Merge(S.state.first + (count - 2), S.state.second + count);
        } else {
            T.state.second = S.state.second + count;
        }

        T.state.first = S.state.first + count;

        // T.CheckHu();

        return T;
    }
};

int main() {
    auto __begin = std::chrono::steady_clock::now();

    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    valueType N;

    std::cin >> N;

    ValueVector Bucket(N + 1, 0);

    for (valueType i = 0; i < 13; ++i) {
        valueType x, t;

        std::cin >> x >> t;

        ++Bucket[x];
    }

    std::map<Mahjong, valueType> ID;
    valueType size = 0;
    ValueMatrix Transfer(4000, ValueVector(5, -1));
    bitset Finish(4000, false);

    {
        std::queue<Mahjong> Q;

        Mahjong start;

        start.SetInit();

        ID[start] = ++size;
        Q.push(start);

        while (!Q.empty()) {
            Mahjong const state = Q.front();

            Q.pop();

            valueType const x = ID[state];

            {
                Mahjong temp = state;

                if (temp.CheckHu())
                    Finish[x] = true;

                assert(temp == state);
            }

            for (valueType count = 0; count <= 4; ++count) {
                Mahjong const next = state + count;

                // if (count == 0)
                //     assert(next == state);

                if (ID.count(next) > 0) {
                    Transfer[x][count] = ID[next];
                } else {
                    Transfer[x][count] = (ID[next] = ++size);
                    Q.push(next);
                }
            }
        }
    }

    std::cerr << "size = " << size << std::endl;
    std::cerr << "Time[1] : " << std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - __begin).count() << "[ms]" << std::endl;

    ValueMatrix FastC(5, ValueVector(5, 0));

    FastC[0][0] = 1;

    for (valueType i = 1; i <= 4; ++i) {
        FastC[i][0] = 1;

        for (valueType j = 1; j <= i; ++j)
            FastC[i][j] = FastC[i - 1][j] + FastC[i - 1][j - 1];
    }

    ValueMatrix F(4 * N + 1, ValueVector(size + 1, 0));

    F[0][1] = 1;

    for (valueType i = 1; i <= N; ++i) {
        ValueMatrix Next(4 * i + 1, ValueVector(size + 1, 0));

        for (valueType j = 0; j <= 4 * (i - 1); ++j) {
            valueType k = 1;
            // for (valueType k = 1; k <= size; ++k) {
            for (k = 1; k + 7 <= size; k += 4) {
                // for (valueType t = Bucket[i]; t <= 4; ++t)
                //     Inc(Next[j + t][Transfer[k][t]], mul(F[j][k], C(4 - Bucket[i], t - Bucket[i])));
                switch (Bucket[i]) {
                    case 0:
                        Inc(Next[j + 0][Transfer[k + 0][0]], mul(F[j][k + 0], FastC[4 - Bucket[i]][0 - Bucket[i]]));
                    case 1:
                        Inc(Next[j + 1][Transfer[k + 0][1]], mul(F[j][k + 0], FastC[4 - Bucket[i]][1 - Bucket[i]]));
                    case 2:
                        Inc(Next[j + 2][Transfer[k + 0][2]], mul(F[j][k + 0], FastC[4 - Bucket[i]][2 - Bucket[i]]));
                    case 3:
                        Inc(Next[j + 3][Transfer[k + 0][3]], mul(F[j][k + 0], FastC[4 - Bucket[i]][3 - Bucket[i]]));
                    case 4:
                        Inc(Next[j + 4][Transfer[k + 0][4]], mul(F[j][k + 0], FastC[4 - Bucket[i]][4 - Bucket[i]]));
                }

                switch (Bucket[i]) {
                    case 0:
                        Inc(Next[j + 0][Transfer[k + 1][0]], mul(F[j][k + 1], FastC[4 - Bucket[i]][0 - Bucket[i]]));
                    case 1:
                        Inc(Next[j + 1][Transfer[k + 1][1]], mul(F[j][k + 1], FastC[4 - Bucket[i]][1 - Bucket[i]]));
                    case 2:
                        Inc(Next[j + 2][Transfer[k + 1][2]], mul(F[j][k + 1], FastC[4 - Bucket[i]][2 - Bucket[i]]));
                    case 3:
                        Inc(Next[j + 3][Transfer[k + 1][3]], mul(F[j][k + 1], FastC[4 - Bucket[i]][3 - Bucket[i]]));
                    case 4:
                        Inc(Next[j + 4][Transfer[k + 1][4]], mul(F[j][k + 1], FastC[4 - Bucket[i]][4 - Bucket[i]]));
                }

                switch (Bucket[i]) {
                    case 0:
                        Inc(Next[j + 0][Transfer[k + 2][0]], mul(F[j][k + 2], FastC[4 - Bucket[i]][0 - Bucket[i]]));
                    case 1:
                        Inc(Next[j + 1][Transfer[k + 2][1]], mul(F[j][k + 2], FastC[4 - Bucket[i]][1 - Bucket[i]]));
                    case 2:
                        Inc(Next[j + 2][Transfer[k + 2][2]], mul(F[j][k + 2], FastC[4 - Bucket[i]][2 - Bucket[i]]));
                    case 3:
                        Inc(Next[j + 3][Transfer[k + 2][3]], mul(F[j][k + 2], FastC[4 - Bucket[i]][3 - Bucket[i]]));
                    case 4:
                        Inc(Next[j + 4][Transfer[k + 2][4]], mul(F[j][k + 2], FastC[4 - Bucket[i]][4 - Bucket[i]]));
                }

                switch (Bucket[i]) {
                    case 0:
                        Inc(Next[j + 0][Transfer[k + 3][0]], mul(F[j][k + 3], FastC[4 - Bucket[i]][0 - Bucket[i]]));
                    case 1:
                        Inc(Next[j + 1][Transfer[k + 3][1]], mul(F[j][k + 3], FastC[4 - Bucket[i]][1 - Bucket[i]]));
                    case 2:
                        Inc(Next[j + 2][Transfer[k + 3][2]], mul(F[j][k + 3], FastC[4 - Bucket[i]][2 - Bucket[i]]));
                    case 3:
                        Inc(Next[j + 3][Transfer[k + 3][3]], mul(F[j][k + 3], FastC[4 - Bucket[i]][3 - Bucket[i]]));
                    case 4:
                        Inc(Next[j + 4][Transfer[k + 3][4]], mul(F[j][k + 3], FastC[4 - Bucket[i]][4 - Bucket[i]]));
                }
            }

            while (k <= size) {
                switch (Bucket[i]) {
                    case 0:
                        Inc(Next[j + 0][Transfer[k][0]], mul(F[j][k], FastC[4 - Bucket[i]][0 - Bucket[i]]));
                    case 1:
                        Inc(Next[j + 1][Transfer[k][1]], mul(F[j][k], FastC[4 - Bucket[i]][1 - Bucket[i]]));
                    case 2:
                        Inc(Next[j + 2][Transfer[k][2]], mul(F[j][k], FastC[4 - Bucket[i]][2 - Bucket[i]]));
                    case 3:
                        Inc(Next[j + 3][Transfer[k][3]], mul(F[j][k], FastC[4 - Bucket[i]][3 - Bucket[i]]));
                    case 4:
                        Inc(Next[j + 4][Transfer[k][4]], mul(F[j][k], FastC[4 - Bucket[i]][4 - Bucket[i]]));
                }

                ++k;
            }
            // }
        }

        F.swap(Next);
    }

    std::cerr << "Time[2] : " << std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - __begin).count() << "[ms]" << std::endl;

    BinomialCoefficient const C(4 * N + 5);

    valueType ans = 0;

    for (valueType i = 13; i <= 4 * N; ++i) {
        valueType sumA = 0, sumB = 0;

        for (valueType j = 1; j <= size; ++j) {
            Inc(sumA, F[i][j]);

            if (!Finish[j])
                Inc(sumB, F[i][j]);
        }

        Inc(ans, mul(sumB, pow(sumA, MOD - 2)));
    }

    // Mul(ans, pow(C.Fact(4 * N - 13), MOD - 2));

    std::cout << ans << std::endl;

    return 0;
}
posted @ 2024-02-03 07:15  User-Unauthorized  阅读(157)  评论(0编辑  收藏  举报