[AGC021F] Trinity 题解

Tag

动态规划,NTT。

Preface

被迫营业。

Description

在一个 \(n\times m\) 的黑白格中,设长度为 \(n\) 的数列 \(A\),长度为 \(m\) 的数列 \(B\)\(C\),分别表示一下规则:

  • \(A_i\) 表示第 \(i\) 行第一个出现的黑格子的位置,如果没有黑格子则为 \(m+1\)
  • \(B_i\) 表示第 \(i\) 列第一个出现的黑格子的位置,如果没有黑格子则为 \(n+1\)
  • \(C_i\) 表示第 \(i\) 列最后一个出现黑格子的位置,如果没有黑格子则为 \(0\)

求最后有多少种不同的 \((A,B,C)\) 三元组。

\(\texttt{data range:} n\leq 8\times 10^3,m\leq 200\)

Solution

不难想到,如果我们限定只有 \(k\) 行有黑色格子出现,显然这 \(k\) 行在任意位置都可以,结果都是一样的,对于最后的方案数我们只用乘上一个组合数就可以了,所以我们现在只考虑有 \(k\) 行有黑格子的情况。

显然可以设计 dp 来计数这个过程。我们设 \(f_{i,j}\) 表示有 \(i\) 行,考虑到第 \(j\) 列的情况数。

一个比较显然的转移就是:

\[f_{i,j} \xrightarrow{1+i+\binom{i}{2}} f_{i,j+1} \]

表示的是,要么不取,要么取一个,要么取两个,由于我们计数的是 \((A,B,C)\) 的情况数,所以不用考虑中间是什么情况。

但是这个时候我们需要好好思考一下怎么从某一行转移过来。

我们讨论,如何从 \(f_{k,j}\to f_{i,j+1}\),其中 \(k < i\)

我们需要注意,我们不是往某一行后面加某个数,而是从中间插入若干的数,由于我们是根据层递推的,所以此时插入的 \(A\) 的值应该为 \(j+1\),然后我们需要枚举要从多少行递推过来的。

我们接下来需要考虑的是 \(B\)\(C\) 的情况,这里的考虑也将是这个题唯一的难点所在。

看起来这里可以直接进行一个组合数的枚举,就是:

\[f_{k,j-1} \xrightarrow{\binom{i}{k}}f_{i,j} \]

这样是会数漏的,因为此时我们只考虑了插入的情况,但是我们压根就没有考虑到如果说我们插入的这一行与之前的某一列重合了就会很尴尬,下面设黑色为 \(1\),白色为 \(0\)

10
01
01
10
01

这样是我们插入的一种情况,问题来了。

11
01
01
10
01

这种情况是和上面的情况在我们的组合数中是同一种情况,但是又不同的表现,所以我们肯定不能单纯的只选出 \(k\) 个来进行插入,我们还需要考虑原有的黑色的情况。

所以我们需要考虑 \(B_i\) 上面的白色格子的位置以及 \(C_i\) 下面的白色格子的位置,这样子的话就不会数漏了,因为一定会存在这两个位置,所以用来转移的组合数就是 \(\dbinom{i+2}{i-k+2}\),也就是多取了两个。

然后我们得到了一个最终的式子:

\[f_{i,j} = (1+i+\dbinom{i}{2})f_{i,j-1}+\sum_{k<i} \dbinom{i+2}{i-k+2}f_{k,j} \]

状态数 \(O(nm)\),转移 \(O(n)\),我们的时间复杂度为 \(O(n^2m)\),显然过不去,但是考虑到后面是个组合数,所以我们很容易就能想出用 NTT 优化,于是我们就做到了时间复杂度 \(O(nm\lg n)\) 和空间复杂度 \(O(nm)\)

Code

const int N = 8e3 + 10;
const int M = 2e2 + 10;
const int mod = 998244353, g = 3;;
using poly = vector<int>;
using ll = long long;

inline void chk(int &x) {x -= mod; x += x >> 31 & mod;}
inline int mll(int x, int y) {return (ll) x * y % mod;}
inline int add(int x, int y) {chk(x += y); return x ;}
inline int del(int x, int y) {return add(x, mod - y);}

inline int ksm(int x, int y) {
    int ret = 1;
    for(; y; y >>= 1, x = mll(x, x))
        if(y & 1) ret = mll(ret, x);
    return ret;
}

ostream &operator << (ostream &ou, const poly F) {
    FOR(i, 0, F.size() - 1) ou << F[i] << ' ';
    return ou;
}


int fc[N + M], fv[N + M], inv[N + M]; 
inline void pref(const int lim) {
    fc[0] = 1;
    FOR(i, 1, lim) fc[i] = mll(fc[i - 1], i);
    fv[lim] = ksm(fc[lim], mod - 2);
    ROF(i, lim, 1) fv[i - 1] = mll(fv[i], i);
    FOR(i, 1, lim) inv[i] = mll(fv[i], fc[i - 1]);
    return ;
}

int C(int n, int m) {
    if(n < m) return 0;
    return mll(fc[n], mll(fv[m], fv[n - m]));
}

int rev[N << 2];
inline int getrev(const int n) {
    int len = 1, tim = -1;
    while(len < n) len <<= 1, tim++;
    FOR(i, 0, len - 1) rev[i] = rev[i >> 1] >> 1 | ((i & 1) << tim);
    return len;
}

void DFT(poly &F, int n) {
    F.resize(n);
    FOR(i, 0, n - 1) if(i < rev[i]) swap(F[i], F[rev[i]]);
    for(int i = 1; i < n; i <<= 1) {
        int gn = ksm(g, (mod - 1) / (i << 1));
        for(int j = 0, g0 = 1, x, y; j < n; j += (i << 1), g0 = 1)
        for(int k = 0; k < i; k++, g0 = mll(gn, g0)) {
            x = F[j + k], y = mll(F[i + j + k], g0);
            F[j + k] = add(x, y);
            F[i + j + k] = del(x, y);
        }
    }
}
void IDFT(poly &F, int n) {
    DFT(F, n), reverse(F.begin() + 1, F.end());
    int iv = ksm(n, mod - 2);
    FOR(i, 0, n - 1) F[i] = mll(F[i], iv);
}

poly operator * (poly A, poly B) {
    int n = A.size() + B.size() - 1, len = getrev(n);
    DFT(A, len), DFT(B, len);
    FOR(i, 0, len - 1) A[i] = mll(A[i], B[i]);
    return IDFT(A, len), A.resize(n), A;
}

int dp[2][N];

inline void solve() {
    int n, m;
    cin >> n >> m;
    pref(n + m);
    FOR(i, 0, n) dp[1][i] = 1;
    FOR(j, 2, m) {
        int op = j & 1;
        poly F(n + 1), G(n + 1);
        FOR(i, 1, n) F[i] = fv[i + 2];
//        dbg(F);
        FOR(i, 0, n) G[i] = mll(dp[op ^ 1][i], fv[i]);
//        dbg(G);
        F = F * G;
//        dbg(F);
        FOR(i, 0, n) {
            dp[op][i] = mll(add(1 + i, C(i, 2)), dp[op ^ 1][i]);
            chk(dp[op][i] += mll(fc[i + 2], F[i]));
        }
    }
    int ans = 0;
    FOR(i, 0, n) chk(ans += mll(C(n, i), dp[m & 1][i]));
    cout << ans << '\n';
    return ;
}
posted @ 2022-02-27 21:12  Kamiya-Kina  阅读(31)  评论(0编辑  收藏  举报