P5664 [CSP-S2019] Emiya 家今天的饭

很久之前的题今天才做。

Solution

发现每种主要食材不能超过\(\left\lfloor \frac{k}{2} \right\rfloor\),那考虑不合法的情况,有也仅有一种主要食材会出现\(> \left\lfloor \frac{k}{2} \right\rfloor\)的情况。
于是,我们考虑如何求出总情况数和不合法情况数,然后再相减就是答案。
考虑枚举不合法的主要食材\(col\)。设\(dp[i][j][k]\)为考虑前\(i\)种烹饪方法,用第\(i\)种烹饪方法第\(col\)种主要食材做了\(j\)道菜,用\(i\)种烹饪方法做除第\(col\)种主要食材其他的菜做了\(k\)道。则

\[dp[i][j][k] = dp[i - 1][j][k] + dp[i - 1][j - 1][k] \cdot a[i][col] + dp[i - 1][j][k - 1] \cdot (s[i] - a[i][col]) \]

其中\(s[i] = \sum_{j = 1} ^ m a[i][j]\)
不合法答案即为\(\sum_{n > j > k} dp[n][j][k]\)
但是这样做是\(\mathcal{O}(m \cdot n^3)\),考虑到\(n \le 100,m \le 1000\),这样完全不行!
考虑到我们计算答案时并不关心\(j,k\)的具体数值,指看\(j > k\),所以我们将第二维变为\(j - k\)的值,即设\(dp[i][x]\)为用前\(i\)种烹饪方法,\(j - k = x\)的情况。
易得

\[dp[i][x] = dp[i - 1][x] + dp[i - 1][x - 1] \cdot a[i][col] + dp[i - 1][x + 1] \cdot (s[i] - a[i][col]) \]

答案为\(\sum_{x > 0} dp[i][x]\)。时间复杂度为\(\mathcal{O}(m \cdot n^2)\)
发现第二维可能为负,要平移处理。
这题的dp优化方法值得学习。
还有就是取模有点奇怪。

#include <bits/stdc++.h>
using namespace std;
# define int long long
const int N = 105,M = 2005;
const int mod = 998244353;
int n,m;
int f[N][N],dp[N][N << 1];
int a[N][M],s[N];
int Hash(int x)
{
    return x + n + 1;
}
signed main(void)
{
    scanf("%lld%lld",&n,&m);
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= m; j++)
            scanf("%lld",&a[i][j]);
    for(int i = 1; i <= n; i++) 
        for(int j = 1; j <= m; j++) s[i] = (s[i] + a[i][j]) % mod;
    f[0][0] = 1;
    for(int i = 1; i <= n; i++)
    {
        for(int j = 0; j <= i; j++) f[i][j] = (f[i - 1][j] + (f[i - 1][j - 1] * (s[i] % mod))) % mod;
    }
    dp[0][Hash(0)] = 1;
    long long ans1 = 0,ans2 = 0;
    for(int col = 1; col <= m; col++)
    {
        memset(dp,0,sizeof(dp));
        dp[0][Hash(0)] = 1ll;
        for(int i = 1; i <= n; i++)
        {
            for(int j = Hash(-i); j <= Hash(i); j++)
            {
                dp[i][j] = (dp[i - 1][j] + dp[i - 1][j - 1] * (a[i][col] % mod) + dp[i - 1][j + 1] * ((s[i] - a[i][col] + mod) % mod)) % mod;
            }
        }
        for(int i = Hash(1); i <= Hash(n); i++) ans1 = (ans1 + dp[n][i]) % mod;
    }
    for(int i = 1; i <= n; i++) ans2 = (ans2 + f[n][i]) % mod;
    printf("%lld\n",(ans2 - ans1 + mod) % mod);

    return 0;
}
posted @ 2021-07-29 10:30  luyiming123  阅读(77)  评论(0编辑  收藏  举报