「CSP-S 2019」Emiya 家今天的饭
知识点: DP,容斥
赛场上硬刚64pts 2.5h,最后样例没过。
没时间写爆搜,交了一份自己也不知道在干什么的代码,获得了 8pts 的好成绩。
留下了极大的心理阴影。
消除恐惧的最好方法,就是直面恐惧。
加油,奥利给!
简述
给定一 \(n\times m\) 的矩阵 \(a\)。
从矩阵中取出一个大小为 \(k\) 的集合,满足下列要求:
- 非空。
- 每一行只能选择一个元素。
- 属于同一列的元素数 \(\le \left\lfloor\frac{k}{2}\right\rfloor\)。
一个集合的价值定义为所有元素的乘积。
求所有满足条件的 集合的权值 之和,答案对 \(998244353\) 取模。
\(1\le n\le 100, 1\le m \le 2000, 0\le a_{i,j} \le 998244353\)。
1S,256MB。
分析
算法一
\(n \le 10, m \le 3\)。
我会暴力!
爆搜,枚举所有满足条件的集合,计算权值并累加。
复杂度\(O(m^n)\),期望得分 32 pts。
算法二
\(m\le 2\) 。
我会更优雅的暴力!
发现枚举到第 \(i\) 行时,影响第 \(i\) 行选择的,只有前 \(i-1\) 行中,选择的元素个数,和每一列选择的元素个数。
考虑 DP。
设 \(f_{i,k,n_1,n_2}\) 表示,前 \(i\) 行中,选择了 \(k\) 个元素,第一列元素选择了 \(n_1\) 个,第二列元素选择了 \(n_2\) 个的权值和。
显然有 \(n_2 = k-n_1\)。
转移类似背包,当枚举到第 \(i\) 行时,显然有:
统计答案时,有:
需要枚举 行数 \(i\),列数 \(j\),元素个数 \(k,n_1\)。
复杂度 \(O(n^{m+1}\times m)\),结合算法一数据分治,期望得分 48 pts。
算法三
\(m\le 3\) 。
在算法二的基础上拓展一维。
复杂度 \(O(n^{m+1}\times m)\),结合算法二数据分治,期望得分 64 pts。
算法四
考虑调整状态。
容易想到,合法集合权值和 = 总权值和 - 不合法集合权值和。
对于总权值和:
每一行可以选择 1 个元素,或者不选。
选择 1 个元素时,其对乘积的贡献为 \(a_{i,j}\),不选时贡献为 1。
最后减去空集的情况(其贡献为 1),则总权值和为:
对于不合法集合权值和:
由条件 3 可知,不合法集合中,只存在一个不合法列,使集合中位于该列的元素数 \(> \left\lfloor\frac{k}{2}\right\rfloor\)。
具有可枚举性,考虑通过枚举固定这一列。判断某一集合是否合法,只需已知 位于该列的元素数,和位于其他列的元素数,求差值即可。
令枚举固定的一列为 \(now\),设计状态:\(f_{i,j,k}\) 表示:前 \(i\) 行中,\(now\) 列中 选择了 \(j\) 个,其他列中选择了 \(k\) 个的权值和。
令 \(s_{i,now} = \left(\sum\limits_{j=1}^{m}a_{i,j}\right) - a_{i,now}\),表示一行中非 \(now\) 列的元素之和。
转移时枚举选到第 \(i\) 行的选择情况,考虑第 \(i\) 行是否选择第 \(now\) 列,则有显然的转移方程:
复杂度 \(O(n^3m)\),期望得分 84 pts。
算法五
发现在算法四中,判断集合是否合法,仅需知道被枚举列和 其他列元素数的差值即可。
分别记录两种元素的个数是没有必要的。
考虑直接记录差值。
算法四转移方程中 \(f_{i-1,j-1,k}\rightarrow f_{i-1,j,k}\) 和 \(f_{i-1,j,k-1}\rightarrow f_{i-1,j,k}\),可看做差值的 增大 / 减小。
同样令固定的一列为 \(now\),可设计新状态,令 \(f_{i,j}\) 表示,前 \(i\) 行中, \(now\) 列中 选择个数 与 其他列个数差值为 \(j\) 时的权值和。
令 \(s_{i,now} = \left(\sum\limits_{j=1}^{m}a_{i,j}\right) - a_{i,now}\),表示一行中非 \(now\) 行的元素之和。
转移时枚举选到第 \(i\) 行的选择情况,考虑第 \(i\) 行是否选择第 \(now\) 列,则有:
复杂度 \(O(n^2m)\),期望得分 100 pts。
实现
算法三
//知识点:DP
/*
By:Luckyblock
*/
#include <cstdio>
#include <ctype.h>
#include <cstring>
#include <algorithm>
const int kN = 42;
const int mod = 998244353;
//=============================================================
int n, m, a[kN][kN];
int ans, f[2][kN][kN / 2][kN / 2][kN / 2];
//=============================================================
inline int read() {
int f = 1, w = 0; char ch = getchar();
for (; !isdigit(ch); ch = getchar()) if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
//=============================================================
int main() {
freopen("meal.in", "r", stdin);
freopen("meal.out", "w", stdout);
n = read(), m = read();
for (int i = 1; i <= n; ++ i) {
for (int j = 1; j <= m; ++ j) {
a[i][j] = read();
}
}
f[0][0][0][0][0] = 1;
for (int i = 1; i <= n; ++ i) { //到达第 i 行
for (int j = 1; j <= m; ++ j) { //枚举第 j 列
if (! a[i][j]) continue ;
int d[4] = {0}; d[j] = 1;
for (int k = 1; k <= i; ++ k) { //将要选择 k 种菜
for (int n1 = d[1]; n1 <= k; ++ n1) {
for (int n2 = d[2]; n2 <= k; ++ n2) {
int n3 = k - n2 - n1;
if (n3 < d[3]) continue ;
f[1][k][n1][n2][n3] += 1ll * f[0][k - 1][n1 - d[1]][n2 - d[2]][n3 - d[3]] * a[i][j] % mod;
f[1][k][n1][n2][n3] %= mod;
}
}
}
}
for (int k = 0; k <= i; ++ k) {
for (int n1 = 0; n1 <= k; ++ n1) {
for (int n2 = 0; n2 <= k; ++ n2) {
int n3 = k - n2 - n1;
if (n3 < 0) continue ;
f[0][k][n1][n2][n3] += f[1][k][n1][n2][n3];
f[0][k][n1][n2][n3] %= mod;
}
}
}
memset(f[1], 0, sizeof(f[1]));
}
for (int k = 2; k <= n; ++ k) {
for (int n1 = 0; n1 <= k / 2; ++ n1) {
for (int n2 = 0; n2 <= std::min(k / 2, k - n1); ++ n2) {
int n3 = k - n2 - n1;
if (n3 > k / 2 || n3 < 0) continue ;
ans += f[0][k][n1][n2][n3];
ans %= mod;
}
}
}
printf("%d\n", ans);
return 0;
}
算法四
//知识点:DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 110;
const int mod = 998244353;
//=============================================================
int n, m, ans = 1, minus, a[kN][20 * kN], sum[kN];
int f[kN][kN][kN];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
//=============================================================
int main() {
freopen("meal.in", "r", stdin);
freopen("meal.out", "w", stdout);
n = read(), m = read();
for (int i = 1; i <= n; ++ i) {
for (int j = 1; j <= m; ++ j) {
a[i][j] = read();
sum[i] = (sum[i] + a[i][j]) % mod;
}
}
for (int i = 1; i <= n; ++ i) {
ans = 1ll * ans * (sum[i] + 1) % mod;
}
ans = (ans - 1 + mod) % mod;
for (int now = 1; now <= m; ++ now) {
memset(f, 0, sizeof(f));
f[0][0][0] = 1;
for (int i = 1; i <= n; ++ i) {
for (int j = 0; j <= i; ++ j) {
for (int k = 0; k <= i - j; ++ k) {
f[i][j][k] = f[i - 1][j][k];
if (j > 0) f[i][j][k] += 1ll * f[i - 1][j - 1][k] * a[i][now] % mod;
f[i][j][k] %= mod;
if (k > 0) f[i][j][k] += 1ll * (sum[i] - a[i][now] + mod) % mod * f[i - 1][j][k - 1] % mod;
f[i][j][k] %= mod;
}
}
}
for (int j = 1; j <= n; ++ j) {
for (int k = 0; k <= n - j; ++ k) {
if (j > k) minus = (minus + f[n][j][k]) % mod;
}
}
}
printf("%d", ((ans - minus) % mod + mod) % mod);
return 0;
}
算法五
//知识点:DP
/*
By:Luckyblock
*/
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#define LL long long
const int kN = 210;
const int mod = 998244353;
//=============================================================
int n, m, ans = 1, minus, a[kN][kN * 10], sum[kN];
int f[kN][kN];
//=============================================================
inline int read() {
int f = 1, w = 0;
char ch = getchar();
for (; !isdigit(ch); ch = getchar())
if (ch == '-') f = -1;
for (; isdigit(ch); ch = getchar()) w = (w << 3) + (w << 1) + (ch ^ '0');
return f * w;
}
//=============================================================
int main() {
freopen("meal.in", "r", stdin);
freopen("meal.out", "w", stdout);
n = read(), m = read();
for (int i = 1; i <= n; ++ i) {
for (int j = 1; j <= m; ++ j) {
a[i][j] = read();
sum[i] = (sum[i] + a[i][j]) % mod;
}
}
for (int i = 1; i <= n; ++ i) {
ans = 1ll * ans * (sum[i] + 1) % mod;
}
ans = (ans - 1 + mod) % mod;
for (int now = 1; now <= m; ++ now) {
memset(f, 0, sizeof(f));
f[0][n] = 1; // 0 + 偏移量n
for (int i = 1; i <= n; ++ i) {
for (int j = n - i; j <= n + i; ++ j) {
f[i][j] = f[i - 1][j];
f[i][j] += 1ll * f[i - 1][j - 1] * a[i][now] % mod;
f[i][j] %= mod;
f[i][j] += 1ll * f[i - 1][j + 1] * (sum[i] - a[i][now] + mod) % mod;
f[i][j] %= mod;
}
}
for (int j = 1; j <= n; ++ j) {
minus = (minus + f[n][j + n]) % mod;
}
}
printf("%d", ((ans - minus) % mod + mod) % mod);
return 0;
}