【NOI2017】泳池
一道很棒的神仙题!
设 \(calc(k)\) 表示区域大小不超过 \(k\) 的概率,那么答案就是 \(calc(k) - calc(k - 1)\)
有一些列的第一个位置就是障碍,它们会把泳池划分为不相关的几个部分。可以从这个角度进行 DP
设 \(f_n\) 表示,连续 \(n\) 列,最后一列的第一个位置是障碍,且最大合法矩形面积不超过 \(k\) 的概率。
设 \(g_n\) 表示,连续 \(n\) 列,每一列的第一个位置都不是障碍,且最大合法矩形面积不超过 \(k\) 的概率。
枚举上一个障碍点在哪,我们就得到转移:
\[f_n = (1-q) \sum_{i = 0} ^ {\min\{n - 1, k\}} f_{n - i - 1} g_{i}
\]
边界条件是 \(f_0 = 1\),我们要求的答案就是 \(\frac{f_{n+1}}{1-q}\)。
现在我们要求出 \(g\) 数组。
设 \(dp_{i,j}\) 表示连续的 \(i\) 列,出现的最低障碍是 \(j+1\), 且最大合法矩形面积不超过 \(k\) 的概率,则 \(g_i = \sum_{j = 1} ^ {\lfloor \frac{k}{i} \rfloor} dp_{i}{j}\)。
我们可以枚举第一个出现的最低障碍进行转移,就得到:
\[dp_{i,j} = (1-q) q^j \sum_{k = 1} ^ i (\sum_{t \geq j + 1} dp_{k - 1, t}) (\sum_{t \geq j} dp_{i - k,j})
\]
因为 \(ij \leq k\), 所以状态数只有 \(O(klogk)\), 加个后缀和优化就可以 \(O(k^2logk)\) 处理了。
现在的问题是,怎样求得 \(f\) 的第 \(n\) 项。
可以发现上面的式子是一个齐次线性递推,直接拖个板子就可以AC啦。
事实上是 UOJ 上被 hack 成 97 分了。
原因是当线性递推式的阶数为 \(1\) 时,需要对一个一次多项式取模。快速幂的时候初始多项式是一次的,如果不取模就挂了。
#pragma GCC optimize("2,Ofast,inline")
#include<bits/stdc++.h>
#define fi first
#define se second
#define mp make_pair
#define pb push_back
#define LL long long
#define pii pair<int, int>
using namespace std;
const int N = 2005;
const int mod = 998244353;
template <typename T> T read(T &x) {
int f = 0;
register char c = getchar();
while (c > '9' || c < '0') f |= (c == '-'), c = getchar();
for (x = 0; c >= '0' && c <= '9'; c = getchar())
x = (x << 3) + (x << 1) + (c ^ 48);
if (f) x = -x;
return x;
}
namespace Comb {
const int Maxn = 1e6 + 10;
int fac[Maxn], fav[Maxn], inv[Maxn];
void comb_init() {
fac[0] = fav[0] = 1;
inv[1] = fac[1] = fav[1] = 1;
for (int i = 2; i < Maxn; ++i) {
fac[i] = 1LL * fac[i - 1] * i % mod;
inv[i] = 1LL * -mod / i * inv[mod % i] % mod + mod;
fav[i] = 1LL * fav[i - 1] * inv[i] % mod;
}
}
inline int C(int x, int y) {
if (x < y || y < 0) return 0;
return 1LL * fac[x] * fav[y] % mod * fav[x - y] % mod;
}
inline int Qpow(int x, int p) {
int ans = 1;
for (; p; p >>= 1) {
if (p & 1) ans = 1LL * ans * x % mod;
x = 1LL * x * x % mod;
}
return ans;
}
inline int Inv(int x) {
return Qpow(x, mod - 2);
}
inline void upd(int &x, int y) {
(x += y) >= mod ? x -= mod : 0;
}
inline int add(int x, int y) {
return (x += y) >= mod ? x - mod : x;
}
inline int dec(int x, int y) {
return (x -= y) < 0 ? x + mod : x;
}
}
using namespace Comb;
namespace Linear {
int n, k;
int a[N], h[N], p[N];
int b[N], c[N];
void module(int *x) {
for (int i = k * 2; i >= k; --i) {
int tmp = 1LL * Inv(p[k]) * x[i] % mod;
for (int j = 0; j <= k; ++j) {
x[i - j] = dec(x[i - j], 1LL * p[k - j] * tmp % mod);
}
}
}
void mul(int *x, int *y, int *z) {
static int res[N];
for (int i = 0; i <= k * 2; ++i) res[i] = 0;
for (int i = 0; i < k; ++i) {
for (int j = 0; j < k; ++j) {
upd(res[i + j], 1LL * x[i] * y[j] % mod);
}
}
module(res);
for (int i = 0; i < k; ++i) z[i] = res[i];
}
void poly_pow(int p) {
while (p) {
if (p & 1) mul(b, c, c);
mul(b, b, b);
p >>= 1;
}
}
int solve() {
if (n <= k * 2) return h[n];
memset(b, 0, sizeof b);
memset(c, 0, sizeof c);
memset(p, 0, sizeof p);
p[k] = 1;
for (int i = 0; i < k; ++i) p[i] = dec(0, a[k - i]);
b[1] = 1; c[0] = 1;
if (k == 1) module(b);
poly_pow(n - k);
int ans = 0;
for (int i = 0; i < k; ++i)
upd(ans, 1LL * h[i + k] * c[i] % mod);
return ans;
}
}
int n, m, x, y, q;
int f[N], g[N], G[N], sdp[N][N], dp[N][N];
int solve() {
memset(f, 0, sizeof f);
memset(g, 0, sizeof g);
memset(dp, 0, sizeof dp);
memset(sdp, 0, sizeof sdp);
dp[0][0] = 1;
for (int i = 1; i <= m; ++i) {
for (int j = m / i; j >= 0; --j) {
for (int k = 1; k <= i; ++k) {
int s1 = 0, s2 = 0;
if (k == 1) s1 = 1;
else s1 = sdp[k - 1][j + 1];
if (i == k) s2 = 1;
else s2 = sdp[i - k][j];
upd(dp[i][j], 1LL * s1 * s2 % mod);
}
dp[i][j] = 1LL * dp[i][j] * Qpow(q, j) % mod;
dp[i][j] = 1LL * dp[i][j] * dec(1, q) % mod;
sdp[i][j] = add(sdp[i][j + 1], dp[i][j]);
}
}
g[1] = dec(1, q);
for (int i = 2; i <= m + 1; ++i) {
g[i] = 1LL * sdp[i - 1][1] * dec(1, q) % mod;
}
f[0] = 1;
for (int i = 1; i <= m * 2 + 2; ++i) {
for (int j = 1; j <= m + 1 && j <= i; ++j) {
upd(f[i], 1LL * f[i - j] * g[j] % mod);
}
}
Linear :: n = n + 1;
Linear :: k = m + 1;
for (int i = 1; i <= m + 1; ++i) {
Linear :: a[i] = g[i];
}
for (int i = 1; i <= m * 2 + 2; ++i) {
Linear :: h[i] = f[i];
}
return 1LL * Linear :: solve() * Inv(dec(1, q)) % mod;
}
int main() {
comb_init();
read(n); read(m); read(x); read(y);
q = 1LL * x * Inv(y) % mod;
int ans1 = solve();
--m;
int ans2 = solve();
// cout << ans1 << ' ' << ans2 << endl;
cout << dec(ans1, ans2) << endl;
return 0;
}