「解题报告」CF708E Student's Camp

感觉 这篇题解 的做法很强啊,贺一下。

连通:考虑将每一种情况对应一条路径。钦定这条路径为能往下则往下,不能往下就向左或向右走到第一个能往下的位置然后往下。

这样只考虑每一种路径,再对应的计算路径相应的情况的概率和。这个是容易计算的,而路径需要记录的状态少了一维,于是就可以 \(O(nm)\) 的解决了。

设三个 DP 数组表示向右走,向左走和向下走的概率和,然后不难转移。

#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1505, MAXK = 100005, P = 1000000007;
int qpow(int a, int b) {
    int ans = 1;
    while (b) {
        if (b & 1) ans = 1ll * ans * a % P;
        a = 1ll * a * a % P;
        b >>= 1;
    }
    return ans;
}
int n, m, a, b, p, k;
int fac[MAXK], inv[MAXK];
int C(int n, int m) {
    if (n < 0 || m < 0 || n < m) return 0;
    return 1ll * fac[n] * inv[m] % P * inv[n - m] % P;
}
int f[MAXK], g[MAXK];
int fd[MAXN][MAXN], fr[MAXN][MAXN], fl[MAXN][MAXN];
void add(int &a, long long b) {
    a = (a + b) % P;
}
int main() {
    scanf("%d%d%d%d%d", &n, &m, &a, &b, &k);
    p = 1ll * a * qpow(b, P - 2) % P;
    fac[0] = 1;
    for (int i = 1; i <= k; i++)
        fac[i] = 1ll * fac[i - 1] * i % P;
    inv[k] = qpow(fac[k], P - 2);
    for (int i = k; i >= 1; i--)
        inv[i - 1] = 1ll * inv[i] * i % P;
    assert(inv[0] == 1);
    for (int i = 0; i <= k; i++) {
        f[i] = 1ll * qpow(p, i) * qpow(1 - p + P, k - i) % P * C(k, i) % P;
    }
    g[0] = f[0];
    for (int i = 1; i <= max(k, m); i++) {
        g[i] = (g[i - 1] + f[i]) % P;
    }
    fr[0][1] = 1;
    for (int i = 0; i < n; i++) {
        for (int j = 1; j <= m; j++) {
            add(fr[i][j + 1], fr[i][j]);
            add(fr[i + 1][j + 1], 1ll * fr[i][j] * f[j - 1] % P * (i == 0 ? 1 : g[m - j]));
            add(fd[i + 1][j], 1ll * fr[i][j] * f[j - 1] % P * 
                (i == 0 ? 1 : g[m - j]) % P * g[m - j]);
            // printf("fr[%d][%d]=%d\n", i, j, fr[i][j]);
        }
        for (int j = m; j >= 1; j--) {
            add(fl[i][j - 1], fl[i][j]);
            add(fl[i + 1][j - 1], 1ll * fl[i][j] * f[m - j] % P * (i == 0 ? 1 : g[j - 1]));
            add(fd[i + 1][j], 1ll * fl[i][j] * f[m - j] % P * 
                (i == 0 ? 1 : g[j - 1]) % P * g[j - 1]);
            // printf("fl[%d][%d]=%d\n", i, j, fl[i][j]);
        }
        for (int j = 1; j <= m; j++) {
            add(fl[i + 1][j - 1], 1ll * fd[i][j] * g[m - j]);
            add(fr[i + 1][j + 1], 1ll * fd[i][j] * g[j - 1]);
            add(fd[i + 1][j], 1ll * fd[i][j] * g[j - 1] % P * g[m - j]);
            // printf("fd[%d][%d]=%d\n", i, j, fd[i][j]);
        }
    }
    int ans = 0;
    for (int i = 1; i <= m; i++) {
        add(ans, fd[n][i]);
    }
    printf("%d\n", ans);
    return 0;
}
posted @ 2023-04-23 15:10  APJifengc  阅读(42)  评论(0编辑  收藏  举报