Emiya 家今天的饭「容斥+DP」
题目描述
Emiya 是个擅长做菜的高中生,他共掌握 \(n\) 种烹饪方法,且会使用 \(m\) 种主要食材做菜。为了方便叙述,我们对烹饪方法从 \(1∼n\) 编号,对主要食材从 \(1∼m\) 编号。
Emiya 做的每道菜都将使用恰好一种烹饪方法与恰好一种主要食材。更具体地, Emiya 会做 \(a_{i,j}\) 道不同的使用烹饪方法 \(i\) 和主要食材 \(j\) 的菜 \((1≤i≤n,1≤j≤m)\),这也意味着 Emiya 总共会做 \(\sum_{i=1}^{n}\sum_{i=1}^{m}a_{i,j}\) 道不同的菜。
Emiya 今天要准备一桌饭招待 Yazid 和 Rin 这对好朋友,然而三个人对菜的搭配 有不同的要求,更具体地,对于一种包含 \(k\) 道菜的搭配方案而言:
- Emiya 不会让大家饿肚子,所以将做至少一道菜,即 \(k≥1\)。
- Rin 希望品尝不同烹饪方法做出的菜,因此她要求每道菜的烹饪方法互不相同。
- Yazid 不希望品尝太多同一食材做出的菜,因此他要求每种主要食材至多在一半的菜(即 \(⌊k/2⌋\) 道菜)中被使用 这里的 \(⌊x⌋\) 为下取整函数,表示不超过 \(x\) 的最大整数。
这些要求难不倒 Emiya,但他想知道共有多少种不同的符合要求的搭配方案。两种方案不同,当且仅当存在至少一道菜在一种方案中出现,而不在另一种方案中出现。
Emiya 找到了你,请你帮他计算,你只需要告诉他符合所有要求的搭配方案数对质数 \(998,244,353\) 取模的结果。
输入格式
第 \(1\) 行两个用单个空格隔开的整数 \(n,m\)。
第 \(2\) 行至第 \(n+1\) 行,每行 \(m\) 个用单个空格隔开的整数,其中第 \(i+1\) 行的 \(m\) 个 数依次为 \(a_{i,1},a_{i,2},...,a_{i,m}\)。
输出格式
仅一行一个整数,表示所求方案数对 \(998,244,353\) 取模的结果。
输入输出样例
输入 #1
2 3
1 0 1
0 1 1
输出 #1
3
输入 #2
3 3
1 2 3
4 5 0
6 0 0
输出 #2
190
输入 #3
5 5
1 0 0 1 1
0 1 0 1 0
1 1 1 1 0
1 0 1 0 1
0 1 1 0 1
输出 #3
742
数据范围
对于所有测试点,保证 \(1≤n≤100,1≤m≤2000\),\(0 \leq a_{i,j} \lt 998,244,353\)。
思路分析
题目很磨叽,其实就是在一个矩阵中,每一行只能选不超过一个数,每一列所选的不可以超过所有选的个数的一半,满足条件的所有的方案数- 这看上去就很 \(DP\),然而如果直接通过转移就得到答案的话貌似根本没法入手
- 这道题提供了合法方案的限制,如果没有限制就相对好处理一点,所以我们可以用一下简单的容斥,即合法的方案数=所有的方案数-不合法的方案数
- 对于总方案数,我们显然要从行先入手,设 \(g[i][j]\) 为前 \(i\) 行选了\(j\) 个点的方案,\(sum[i]\) 为第 \(i\) 行所有点的和,通过第 \(i\) 行选还是不选就可以得到转移方程:
\[g[i][j] = g[i-1][j]+sum[i]*g[i-1][j-1]
\]
- 对于不合法的方案数相对难搞一些。因为行而不合法的大可不必去管,直接强制让其合法就好了。关键在于列这里,而如果因为列出现了不合法,那么导致方案不合法的列只有一列,也就是这一列的数量比其他列的加起来还要多。所以据此列出状态,首先枚举每一列 \(col\),然后设 \(f[i][j]\) 为前 \(i\) 行中 当前列选的比其他所有列选的多 \(j\) 个的方案数,转移分为三种:1.第 \(i\) 行不选、2.选且选了 \(col\) 列、3.选但选了其他列,则也很容易得到转移方程:
\[f[i][j] = f[i-1][j]+a[i][col]*f[i-1][j-1]+(sum[i]-a[i][col])*f[i-1][j+1]
\]
- 最后用总方案减去不合法的就是答案了
Code
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define R register
#define N 110
#define M 2020
using namespace std;
inline int read(){
int x = 0,f = 1;
char ch = getchar();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
const int mod = 998244353;
int n,m,a[N][M];
long long sum[N],g[N][N],f[N][N<<1],ans;
int main(){
n = read(),m = read();
for(R int i = 1;i <= n;i++){
for(R int j = 1;j <= m;j++)a[i][j] = read(),sum[i]=(a[i][j]+sum[i])%mod;
}
g[0][0] = 1;
for(R int i = 1;i <= n;i++){
for(R int j = 0;j <= i;j++){
if(j>0)g[i][j] = (g[i-1][j]+sum[i]*g[i-1][j-1]%mod)%mod;
else g[i][j] = g[i-1][j];
}
}
for(R int i = 1;i <= n;i++)ans = (ans+g[n][i])%mod;
for(R int col = 1;col <= m;++col){//每一列都要单独处理
memset(f,0,sizeof(f));
f[0][n] = 1;
for(R int i = 1;i <= n;++i){
for(R int j = n-i;j <= n+i;++j){
f[i][j] = (f[i-1][j]+a[i][col]*f[i-1][j-1]%mod+(sum[i]-a[i][col])*f[i-1][j+1]%mod)%mod;
}
}
for(R int i = 1;i <= n;i++)ans = (ans-f[n][n+i]+mod)%mod;
}
printf("%lld\n",ans);
return 0;
}