「解题报告」ARC136F Flip Cells
感觉 AtCoder 上高于铜的难度就完全是随机了,这题完全是金吧,怎么可能只有铜啊,我快写自闭了。
以下记 \(n=hw\)。
我们设三个 DP 数组:
\(f_i\) 为从初始状态经过 \(i\) 步之后第一次达到终止状态的概率;
\(g_i\) 为从终止状态经过 \(i\) 步之后达到终止状态的概率;
\(h_i\) 为从初始状态经过 \(i\) 步之后达到终止状态的概率。
那么如果把三个 DP 数组都写成普通生成函数 \(F(x), G(x), H(x)\),那么有显然的关系 \(F(x)G(x) = H(x)\)。
可以发现 \(F(x)\) 是一个概率生成函数,我们要求的步数期望值就是 \(F'(1)\),而根据上述关系 \(F(x)=\frac{H(x)}{G(x)}\),这样我们就成功的去掉了第一次的限制。
那么接下来我们考虑去求 \(G(x),H(x)\)。
发现 \(G(x)\) 和 \(H(x)\) 是类似的,只有起始状态的区别,下面以 \(G(x)\) 的求法为例子。
首先发现翻转格子是有先后顺序的,也就是说这应当是一个 EGF,那么我们就先求 EGF。
假如说我们现在有一个终止状态,我们要去求到达该终止状态的生成函数 \(G(x)\),那么我们可以根据现在的状态与这个终止状态,得出哪些格子需要被翻转奇数次,哪些格子需要被翻转偶数次。那么我们可以用一个生成函数来表示每一个格子,总的概率就是所有格子的生成函数的乘积。
记翻转奇数次格子的生成函数为 \(E_1(x)\),偶数次的为 \(E_0(x)\),不难发现我们要求的就是 \([x^i]E_c(x) = \left(\frac{1}{n}\right)^i[i \bmod 2 = c]\),根据经验得出:
那么假如当前终止状态需要 \(i\) 个格子翻转奇数次,那么它的生成函数就是 \(E_1^i(x)E_0^{n - i}(x)\)。
对于所有的终止状态来说,我们可以处理出一个 DP 数组 \(gc_i\) 表示有 \(i\) 个格子需要翻转奇数次的终止状态有多少。
我们对每一行求出,然后再卷积起来就是最终的 DP 数组。求单独一行就枚举这一行翻转多少个 \(0\),根据初始状态和最终状态计算出需要翻转多少个 \(1\),组合数计算方案即可。
这样我们就能得到 \(G(x)\) 的 EGF 形式了:
但是我们需要的是 OGF,不是 EGF,我们还需要把 EGF 转回 OGF。可以发现,\(G(x)\) 计算出来之后是若干 \(e^{\Large\frac{ax}{n}}\) 的线性组合,而 \(e^{ax}\) 对应的 OGF 就是 \(\frac{1}{1-ax}\),所以假如我们能够将原生成函数写成 \(\displaystyle G_{\exp}(x) =\sum_{|i| \le n} a_ie^{\Large \frac{ix}{n}}\) 的形式,那么它对应的 OGF 就是 \(\displaystyle G(x) = \sum_{|i| \le n} \frac{a_i}{1 - \frac{ix}{n}}\)。
那么直接开始拆:
其中 \(\displaystyle a_k = [x^k] \sum_{i=0}^n \frac{gc_i}{2^n} (x-1)^i (x+1)^{n-i}\)。
我们尝试直接计算出后面这个多项式,套路的换元 \(y = x + 1\) 消掉一个幂,然后直接计算出关于 \(y\) 的多项式,再换回去拆开即可。
这样我们就能够得到 \(G(x)\) 了。
但是当 \(i = n\) 时,这个式子中存在 \(\frac{1}{1-x}\),代入 \(x=1\) 会使得它无意义。
找回上面的式子 \(F(x) = \frac{H(x)}{G(x)}\),为了把 \(\frac{1}{1-x}\) 消掉,我们可以上下同时乘一个 \((1-x)\),计算 \(F(x) = \frac{(1-x)H(x)}{(1-x)G(x)}\)。
这样对于 \(i=n\) 的情况,\(\frac{1-x}{1-x} = 1, \left(\frac{1-x}{1-x}\right)' = 0\);
对于 \(i \ne n\) 的情况,设 \(k = \frac{2i-n}{n}\),那么 \(\frac{1-x}{1-kx} = 0, \left(\frac{1-x}{1-kx}\right)' = \frac{-(1-kx)+(1-x)k}{(1-kx)^2}=\frac{k-1}{(1-kx)^2}\),代入 \(x=1\) 得到 \(\frac{1}{k-1}\)。
那么我们就可以计算答案了。
总复杂度 \(O(n^2) = O(h^2w^2)\)。
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 2505, P = 998244353;
int h, w, n;
int a[MAXN], b[MAXN];
int C[MAXN][MAXN];
char s[MAXN];
int gc[MAXN], hc[MAXN];
int G[MAXN], H[MAXN];
int pow2[MAXN];
int qpow(int a, int b) {
int ans = 1;
while (b) {
if (b & 1) ans = 1ll * ans * a % P;
a = 1ll * a * a % P;
b >>= 1;
}
return ans;
}
void getCount(int a[], int b[], int f[]) {
static int tmp1[MAXN], tmp2[MAXN];
f[0] = 1;
for (int i = 1; i <= h; i++) {
memset(tmp1, 0, sizeof tmp1);
for (int p = 0; p <= a[i]; p++) {
// p: 选 1 的个数 q: 选 0 的个数
int q = b[i] - (a[i] - p);
if (q >= 0 && q <= w - a[i]) {
tmp1[p + q] = (tmp1[p + q] + 1ll * C[a[i]][p] * C[w - a[i]][q]) % P;
}
}
memcpy(tmp2, f, sizeof tmp2);
memset(f, 0, sizeof tmp2);
for (int i = 0; i <= w; i++) {
for (int j = 0; i + j <= n; j++) {
f[i + j] = (f[i + j] + 1ll * tmp1[i] * tmp2[j]) % P;
}
}
}
}
void getPoly(int f[], int c[]) {
static int tmp[MAXN];
memset(tmp, 0, sizeof tmp);
int inv = qpow(pow2[n], P - 2);
for (int k = 0; k <= n; k++) {
for (int i = n - k; i <= n; i++) {
tmp[k] = (tmp[k] + 1ll * c[i] * inv % P * C[i][k - n + i] % P *
(((n - k) & 1) ? P - 1ll : 1ll) % P * pow2[n - k]) % P;
}
}
for (int i = 0; i <= n; i++) {
for (int j = i; j <= n; j++) {
f[i] = (f[i] + 1ll * tmp[j] * C[j][i]) % P;
}
}
}
int main() {
scanf("%d%d", &h, &w);
n = h * w;
C[0][0] = 1;
for (int i = 1; i <= n; i++) {
C[i][0] = 1;
for (int j = 1; j <= i; j++) {
C[i][j] = (C[i - 1][j - 1] + C[i - 1][j]) % P;
}
}
pow2[0] = 1;
for (int i = 1; i <= n; i++) {
pow2[i] = 2ll * pow2[i - 1] % P;
}
for (int i = 1; i <= h; i++) {
scanf("%s", s + 1);
for (int j = 1; j <= w; j++) {
if (s[j] == '1') a[i]++;
}
}
for (int i = 1; i <= h; i++) {
scanf("%d", &b[i]);
}
getCount(b, b, gc);
getCount(a, b, hc);
getPoly(G, gc);
getPoly(H, hc);
int g1 = G[n], h1 = H[n], gd1 = 0, hd1 = 0;
for (int i = 0; i < n; i++) {
int inv = qpow((1ll * (2 * i - n) * qpow(n, P - 2) % P - 1 + P) % P, P - 2);
gd1 = (gd1 + 1ll * G[i] * inv) % P;
hd1 = (hd1 + 1ll * H[i] * inv) % P;
}
int ans = (1ll * hd1 * g1 % P - 1ll * h1 * gd1 % P + P) * qpow(1ll * g1 * g1 % P, P - 2) % P;
printf("%d\n", ans);
return 0;
}