【CSP-S 2019】D2T1 Emiya 家今天的饭
Description
Solution
算法1 32pts
爆搜,复杂度\(O((m+1)^n)\)
算法2 84pts
裸的dp,复杂度\(O(n^3m)\)
首先有一个显然的性质要知道:
最多只有一种主要食材出现在超过一半的主要食材里。
接下来考虑如果只有前两个限制条件的情况,那么答案就是
其中\(sum_i = \sum \limits_{j=1}^m a_{i,j}\),\(+1\)是因为对于每一行只有选一道菜或者不选这些选择,\(-1\)是因为要去除一道菜都不选的情况。
对于第3个限制条件,发现直接做不太好做,考虑容斥,即用总方案数,也就是上面的式子,减去不合法的方案数。
由最开始的那个性质可以得到一个做法:
枚举不合法的那一种主要食材,然后进行\(dp\)。发现我们并不需要知道每一种主要食材具体用在了多少道菜上,只需要知道当前枚举的食材用在了多少道菜,其它的并不影响方案。那么设\(f_{i,j,k}\)表示前\(i\)中烹饪方式,选了\(j\)道菜,其中\(k\)道的主要食材是枚举的不合法食材。转移分三种情况:令\(s\)表示当前枚举的不合法食材,
-
不在这一种烹饪方式中进行选择:\(f_{i,j,k}=f_{i-1,j,k}\)
-
在这种烹饪方式中选择了合法的食材:\(f_{i,j,k}=(sum_i-a_{i,s}) \times f_{i,j-1,k}\)
-
在这种烹饪方式中选择了不合法的食材:\(f_{i,j,k}=a_{i,s}\times f_{i,j-1,k-1}\)
那么不合法的方案数就是
code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int _ = 100 + 10;
const int __ = 2000 + 10;
int n, m, A[_][__];
ll sum[_], f[_][_ << 1], tmp, ans = 1;
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
scanf("%d", &A[i][j]);
sum[i] = (sum[i] + A[i][j]) % mod;
}
ans = ans * (sum[i] + 1) % mod;
}
ans = (ans - 1 + mod) % mod;
for (int k = 1; k <= m; ++k) {
memset(f, 0, sizeof(f));
f[0][n] = 1;
for (int i = 1; i <= n; ++i) {
for (int j = -i + n; j <= i + n; ++j) {
f[i][j] = (f[i - 1][j] + f[i - 1][j - 1] * A[i][k] % mod + f[i - 1][j + 1] * (sum[i] - A[i][k]) % mod) % mod;
if (i == n && j > n) tmp = (tmp + f[i][j]) % mod;
}
}
}
ans = (ans - tmp + mod) % mod;
printf("%lld\n", ans);
return 0;
}
算法三 100pts
考虑如何对算法二的\(dp\)进行优化,减少不必要的状态。对限制三进行转化,限制三即为
发现并不需要关心使用了食材的菜的具体数量,只需要关心合法与不合法的菜的差值即可,即这个差值与原来差值相同的状态的集合是对应的,那么我们就可以以此为状态进行dp,转移与上面是类似的。
唯一要注意的一点是可能出现负数,要加上一个偏移量\(n\)
-
不在这一种烹饪方式中进行选择:\(f_{i,j}=f_{i-1,j}\)
-
在这种烹饪方式中选择了合法的食材:\(f_{i,j}=(sum_i-a_{i,s}) \times f_{i,j+1}\)
-
在这种烹饪方式中选择了不合法的食材:\(f_{i,j}=a_{i,s}\times f_{i,j-1}\)
code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 998244353;
const int _ = 100 + 10;
const int __ = 2000 + 10;
int n, m, A[_][__];
ll sum[_], f[_][_ << 1], tmp, ans = 1;
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j) {
scanf("%d", &A[i][j]);
sum[i] = (sum[i] + A[i][j]) % mod;
}
ans = ans * (sum[i] + 1) % mod;
}
ans = (ans - 1 + mod) % mod;
for (int k = 1; k <= m; ++k) {
memset(f, 0, sizeof(f));
f[0][n] = 1;
for (int i = 1; i <= n; ++i) {
for (int j = -i + n; j <= i + n; ++j) {
f[i][j] = (f[i - 1][j] + f[i - 1][j - 1] * A[i][k] % mod + f[i - 1][j + 1] * (sum[i] - A[i][k]) % mod) % mod;
if (i == n && j > n) tmp = (tmp + f[i][j]) % mod;
}
}
}
ans = (ans - tmp + mod) % mod;
printf("%lld\n", ans);
return 0;
}