luogu P5664 Emiya 家今天的饭 容斥+dp

参考博客
https://ksmeow.moe/meal-csps219-sol/

//首先考虑列的限制,发现若有不合法的列,
//则必然有且只有一列是不合法的:因为不可能有不同的两列数量都超过总数的一半

//合法方案的容斥计算:每行选不超过一个的方案数-每行选不超过一个且某一列选了超过一半的方案数
//处理不合法的方案,也就是超过一半的 
//f[i,j,k]表示对col这一列,前i行在col列选了j个,在其他列选了k个,s[i]为第i行的总和,那么转移:
//f[i,j,k]=f[i-1,j,k]+a[i,col]*f[i-1,j-1,k]+(s[i]-a[i,col])*f[i-1,j,k-1]
//	  在这一行的这一列不选 	在当前位置有几种选法	在这一列的其他列选

//g[i,j]为前i行共选了j个数的方案数,那么
//g[i,j]=g[i-1,j]+s[i]*g[i-1,j]

//对于f[i,j,k],在状态转移的过程中,实际并不关系j和k的实际大小,只关心相对大小
//那么定义为f[i,j],表示前i行,当前列的数比其他列的数多了j个 ,那么状态转移方程 
//f[i,j]=f[i-1,j]+a*[i,col]*f[i-1,j-1]+(s[i]-a[i,col])*f[i-1,j+1]
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=2005;
const int mod=998244353;
int n,m,a[105][N*2],sum[105][N*2];
ll f[105][N*2],g[105][N*2];
int main()
{
    cin >> n >> m;
    //输入数据+处理前缀和 
    for(int i = 1; i<=n; i++)
        for(int j = 1; j<=m; j++)
        {
            scanf("%d",&a[i][j]);
            sum[i][0] = (sum[i][0]+a[i][j])%mod;
        }
    //预处理出 除去某一个格子之后,当前行剩下的数值之和 
    for(int i = 1; i<=n; i++)
        for(int j = 1; j<=m; j++)
            sum[i][j] = (sum[i][0]-a[i][j]+mod)%mod;
    ll ans = 0;
    //合法方案容斥计算:每行选不超过一个的方案数-每行选不超过一个且某一列选了超过一半的方案数 
    //那么定义为f[i,j],表示前i行,当前列的数比其他列的数多了j个 ,那么状态转移方程 
	//f[i,j]=f[i-1,j]+a*[i,col]*f[i-1,j-1]+(s[i]-a[i,col])*f[i-1,j+1]
    for(int col = 1; col<=m; col++)
    {
        memset(f,0,sizeof(f)); 
        f[0][n]=1;//前0行,当前列的数比其他列的数多n个的方案只有一个,是一个不合法的方案 
        for(int i = 1; i<=n; i++)
        	//因为差值可能为负数,所以搞一个为n的偏移量 
        	//n-i表示差值最小时,也就是前面都选,而且不在这一列
			//n+i表示差值最大时,也就是前面都选,而且都在这一列 
            for(int j = n-i; j<=n+i; j++) 
                f[i][j] = (f[i-1][j]+f[i-1][j-1]*a[i][col]%mod+f[i-1][j+1]*sum[i][col]%mod)%mod;
        for(int j = 1; j<=n; j++)
            ans = (ans+f[n][n+j])%mod;
    }
    //所有方案数
	//什么都没做,有一种方案 
	//g[i,j]为前i行每行最多选1个,共选了j个数的方案数,那么
    g[0][0] = 1;
    for(int i = 1; i<=n; i++)
        for(int j = 0; j<=n; j++) 
            g[i][j] = (g[i-1][j]+(j>0?g[i-1][j-1]*sum[i][0]%mod:0))%mod;
    int ans1=0;
    for(int j = 1; j<=n; j++)
    	ans1=(ans1+g[n][j])%mod;
    cout << (ans1-ans+mod)%mod << endl;
    return 0;
} 
posted @ 2020-03-29 17:51  晴屿  阅读(85)  评论(0编辑  收藏  举报