CF708E Student's Camp 题解
题目大意
一个 \((n + 2) \times m\) 的网格。
除了第一行与最后一行,每一行都有 \(p\) 的概率消失,求 \(k\) 天后,网格始终保持联通的概率。
答案对 \(10^9 + 7\) 取模。
\(\text{Data Range:} 1 \leq n,m \leq 1.5 \times 10^3, k\leq 10^5\)。
不难发现最后每一行剩下的一定都会是一个连续的区间,而且每行之间相互独立。
自然设立状态 \(dp_{i,l,r}\) 表示当前是第 \(i\) 行,\(k\) 天后这一行剩下的区间为 \([l, r]\),第 \(0\) 行到第 \(i\) 行都联通的概率。
先考虑算变成区间 \([l,r]\) 的概率。
设 \(k\) 天内,消失 \(i\) 个格子的概率为 \(g_i = \binom{k}{i} p^i \times (1-p)^{k-i}\)。
所以剩下区间 \([l,r]\) 的概率为 \(g_{l-1} \times g_{m-r}\)。
当然还需要检查这段区间长度是否合法,因为最多消失 \(2k\) 个格子。
转移式子不难写出。
复杂度 \(nm^4\),显然不行,而且状态数都为 \(nm^2\)。
那考虑优化状态,去掉左端点的限制。
设 \(dp_{i,r}\) 表示第 \(i\) 行的剩余区间的右端点为 \(r\),第 \(0\) 行到第 \(i\) 行都联通的概率,转移的时候枚举 \(l\) 进行转移。
考虑一个和区间 \([l,r]\) 有交的区间 \([l_1,r_1]\)。
它满足 \(r_1 \geq l, l_1 \leq r\),那么容斥把这一部分算出来即可。
对于 \(r_1 < l\) 的部分,枚举 \(l\) 时计算 \(\sum_{j=1}^{l-1} dp_{i-1,j}\)。
对于 \(l_1 > r\) 的部分,因为网格是对称的,所以 \(f_{i,r}\) 也可以表示左端点为 \(m-r+1\) 且联通的概率,因此这一部分可以计算为 \(\sum_{j=1}^{m-r} f_{i-1,j}\)。
于是转移如下。
这东西长得就很一脸前缀和的样子,那就考虑前缀和优化。
令 \(s_{j}\) 表示 \(\sum_{k=1}^j f_{i-1,k}\)。
在令 \(p_i\) 和 \(q_i\) 依次为 \(g_i\) 与 \(g_i \times s_i\) 的前缀和,之后即可做到 \(O(1)\) 转移,总时间复杂度 \(o(nm)\)。
// 德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱德丽莎你好可爱
// 德丽莎的可爱在于德丽莎很可爱,德丽莎为什么很可爱呢,这是因为德丽莎很可爱!
#include <bits/stdc++.h>
#define int long long
using namespace std;
inline int read() {
int x = 0, f = 1; char ch = getchar();
while( !isdigit(ch) ) { if(ch == '-') f = -1; ch = getchar(); }
while( isdigit(ch) ) { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
return x * f;
}
const int N = 3e6, mod = 1e9 + 7;
int n, m, a, b, pt, pt2, k, ans, fac[N], ifac[N], dp[3005][3005], g[3005];
int power(int a,int b) { int ans = 1; while (b) { if (b & 1) ans = ans * a % mod; a = a * a % mod; b >>= 1;} return ans; }
int binom(int n,int m) { return fac[n] * ifac[m] % mod * ifac[n - m] % mod; }
int s[N], p[N], q[N];
signed main () {
n = read(), m = read(), a = read(), b = read(), k = read();
fac[0] = 1; for (int i = 1; i <= k; i++) fac[i] = fac[i - 1] * i % mod;
ifac[k] = power(fac[k], mod - 2); for (int i = k - 1; ~i; i--) ifac[i] = ifac[i + 1] * (i + 1) % mod;
pt = a * power(b, mod - 2) % mod; pt2 = 1 - pt; pt2 += mod; pt2 %= mod;
for (int i = 0; i <= k && i <= m; i++) g[i] = binom(k, i) * power(pt, i) % mod * power(pt2, k - i) % mod;
dp[0][m] = 1; p[0] = g[0];
for (int i = 1; i <= m; i++) p[i] = p[i - 1] + g[i], p[i] %= mod;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= m; j++) s[j] = s[j - 1] + dp[i - 1][j], s[j] %= mod;
for (int j = 1; j <= m; j++) q[j] = q[j - 1] + g[j] * s[j] % mod, q[j] %= mod;
for (int r = 1; r <= m; r++) dp[i][r] = g[m - r] * ( (s[m] - s[m - r] + mod) % mod * p[r - 1] % mod - q[r - 1] + mod) % mod;
}
for (int i = 1; i <= m; i++) { ans += dp[n][i]; ans %= mod;}
printf("%lld\n", ans);
return 0;
}