Luogu5664 Emiya 家今天的饭

https://www.luogu.com.cn/problem/P5664

\(DP\)

首先,根据题意,每一行至多取一个

直接求状态数太多,正难则反

我们用全部方案数减去不符合要求的方案数(注意除去全不取情况)

这里不符合要求的方案即食材数超过一半的那些方案

枚举哪一列食材数超过一半,其他任意取,记录非当前列取了多少次,当前列取了多少次

\(dp_{i,j,k}\)表示前\(i\)行,非当前列取了\(j\)次,当前列取了\(k\)次的方案数

\(S_i\)表示第\(i\)行的所有数之和

注:以下式子中数据范围略去

\[dp_{i,j,k}=dp_{i-1,j,k}+dp_{i-1,j-1,k}\times (S_i-a_{i,j})+dp_{i-1,j,k-1}\times a_{i,j}\\ ans=\sum_{j<k} dp_{n,j,k} \]

可以滚动掉一维数组

时间复杂度:\(O(n^3m)\)

\(Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define N 105
#define M 2005
#define p 998244353
#define ll long long 
using namespace std;
int n,m,sm[N],a[N][M],dp[2][N][N];
int ans,del;
void solve1()
{
    ans=1;
    for (int i=1;i<=n;i++)
        ans=(ll)ans*(sm[i]+1)%p;
}
void pl(int &x,int y)
{
    x=(x+y)%p;
}
void solve2()
{
    for (int i=1;i<=m;i++)
    {
        int cur=0;
        memset(dp,0,sizeof(dp));
        dp[cur][0][0]=1;
        for (int j=1;j<=n;j++)
        {
            cur^=1;
            for (int k=0;k<=j;k++)
                for (int t=0;t<=j-k;t++)
                {
                    dp[cur][k][t]=dp[cur^1][k][t];
                    if (k)
                        pl(dp[cur][k][t],(ll)dp[cur^1][k-1][t]*(sm[j]-a[j][i])%p);
                    if (t)
                        pl(dp[cur][k][t],(ll)dp[cur^1][k][t-1]*a[j][i]%p);
                }
        }
        for (int j=0;j<=n;j++)
            for (int k=j+1;k<=n-j;k++)
                pl(del,dp[cur][j][k]);
    }
}
int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<=n;i++)
        for (int j=1;j<=m;j++)
            scanf("%d",&a[i][j]),sm[i]=(sm[i]+a[i][j])%p;
    solve1();
    solve2();
    ans-=del+1;
    ans=(ans%p+p)%p;
    printf("%d\n",ans);
    return 0;
}

观察\(dp\)方程,我们只关心\(j,k\)的差,因此\(dp\)方程中仅记录\(j,k\)之差即可

\(t=k-j+n\)(防负数)

\[dp_{i,t}=dp_{i-1,t}+dp_{i-1,t+1}\times (S_i-a_{i,j})+dp_{i-1,t-1}\times a_{i,j}\\ ans=\sum_{t>n} dp_{n,t} \]

同样可以滚动掉一维数组

时间复杂度:\(O(n^2m)\)

\(C++ Code:\)

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#define N 105
#define M 2005
#define p 998244353
#define ll long long 
using namespace std;
int n,m,sm[N],a[N][M],dp[2][N << 1];
int ans,del;
int read()
{
    int s=0;
    char c=getchar();
    while (c<'0' || c>'9')
        c=getchar();
    while ('0'<=c && c<='9')
    {
        s=s*10+c-'0';
        c=getchar();
    }
    return s;
}
void solve1()
{
    ans=1;
    for (int i=1;i<=n;i++)
        ans=(ll)ans*(sm[i]+1)%p;
}
void pl(int &x,int y)
{
    x=(x+y)%p;
}
void solve2()
{
    for (int i=1;i<=m;i++)
    {
        int cur=0;
        memset(dp,0,sizeof(dp));
        dp[cur][n]=1;
        for (int j=1;j<=n;j++)
        {
            cur^=1;
            for (int k=n-j;k<=n+j;k++)
            {
                dp[cur][k]=dp[cur^1][k];
                if (k<(n << 1))
                    pl(dp[cur][k],(ll)dp[cur^1][k+1]*(sm[j]-a[j][i])%p);
                if (k)
                    pl(dp[cur][k],(ll)dp[cur^1][k-1]*a[j][i]%p);
            }
        }
        for (int k=n+1;k<=(n << 1);k++)
            pl(del,dp[cur][k]);
    }
}
int main()
{
    n=read(),m=read();
    for (int i=1;i<=n;i++)
        for (int j=1;j<=m;j++)
            a[i][j]=read(),sm[i]=(sm[i]+a[i][j])%p;
    solve1();
    solve2();
    ans-=del+1;
    ans=(ans%p+p)%p;
    printf("%d\n",ans);
    return 0;
}
posted @ 2020-09-04 20:49  GK0328  阅读(134)  评论(0编辑  收藏  举报