P5664 [CSP-S2019] Emiya 家今天的饭 题解
题意就是每行只能至多选一个格子,每列选的格子数不能超过一半,所有可行方案的选的格子内的数的乘积之和。
注意到任意一列选不超过一半的格子不方便计算,但是至多只能有一列能超过一半,因此考虑容斥处理,总方案数减去不合法方案数,下记 \(s_i=\sum_{j=1}^ma_{i,j}\)。
总方案数显然,\(\prod_{i=1}^n(s_i+1)-1\),累乘中加一是因为这一行可以不选,最后减一是要删除全部不选的情况。
对于不合法方案数,考虑枚举不合法的那一列 \(p\),此时注意到我们只关心在 \(p\) 这一列选了多少个,因为非 \(p\) 的列对于计算方案数来讲是完全等价的(从 \(s_i-a_{i,p}\) 里面任选一个),因此设 \(f_{i,j,k}\) 表示前 \(i\) 行中,第 \(p\) 列选了 \(j\) 个,剩下的列选了 \(k\) 个的方案数,枚举这一行选哪一个或者是不选,则有转移方程:
\[f_{i,j,k}=f_{i-1,j,k}+f_{i-1,j-1,k}\times a_{i,p}+f_{i-1,j,k-1}\times(s_i-a_{i,p})
\]
初值 \(f_{0,0,0}=1\),最后不合法方案数为 \(\sum_{j>k}f_{n,j,k}\)。
注意到这个转移 \(O(n^3m)\) 不大行,考虑优化,此时注意到答案式中我们只关心 \(j,k\) 的相对大小而并不关心选了几个,因此更改状态,设 \(f_{i,j}\) 表示前 \(i\) 行中,第 \(p\) 列比别的列多少选了 \(j\) 个的方案数,则有转移方程:
\[f_{i,j}=f_{i-1,j}+f_{i-1,j-1}\times a_{i,p}+f_{i-1,j+1}\times(s_i-a_{i,p})
\]
初值 \(f_{0,0}=1\),最后不合法方案数为 \(\sum_{j>0}f_{n,j}\)。
写代码时注意开 long long,随时取模,以及因为有负下标所以要将所有 \(j\) 加上 \(n+1\) 保证所有下标都非负。
Code:
/*
========= Plozia =========
Author:Plozia
Problem:P5664 [CSP-S2019] Emiya 家今天的饭
Date:2022/10/17
========= Plozia =========
*/
#include <bits/stdc++.h>
typedef long long LL;
const int MAXN = 100 + 5, MAXM = 2000 + 5;
const LL P = 998244353;
int n, m;
LL a[MAXN][MAXM], f[MAXN][MAXN << 1], sum[MAXN], ans;
int Read()
{
int sum = 0, fh = 1; char ch = getchar();
for (; ch < '0' || ch > '9'; ch = getchar()) fh -= (ch == '-') << 1;
for (; ch >= '0' && ch <= '9'; ch = getchar()) sum = (sum << 3) + (sum << 1) + (ch ^ 48);
return sum * fh;
}
int main()
{
n = Read(), m = Read(); for (int i = 1; i <= n; ++i) for (int j = 1; j <= m; ++j) sum[i] = (sum[i] + (a[i][j] = Read())) % P;
for (int _ = 1; _ <= m; ++_)
{
for (int i = 1; i <= n; ++i) for (int j = 0; j <= 2 * n + 1; ++j) f[i][j] = 0;
f[0][n + 1] = 1;
for (int i = 1; i <= n; ++i)
for (int j = -i; j <= i; ++j)
f[i][j + n + 1] = (f[i - 1][j + n + 1] + f[i - 1][j - 1 + n + 1] * a[i][_] % P + f[i - 1][j + 1 + n + 1] * (((sum[i] - a[i][_]) % P + P) % P) % P) % P;
for (int j = 1; j <= n; ++j) ans = (ans + f[n][j + n + 1]) % P;
}
LL s = 1; for (int i = 1; i <= n; ++i) s = s * (sum[i] + 1) % P; s = (s - 1 + P) % P;
printf("%lld\n", ((s - ans) % P + P) % P); return 0;
}