Max 题解
一、题目:
二、思路:
非常适合我学习的一道DP题!😆
首先说一下我在考场的思路,拿到了45分。
考虑设 \(f[i,S]\) 表示当前操作了 \(i\) 步,每个元素值的情况为 \(S\) 的概率。这里的 \(S\) 可以用 \(c\times m+1\) 进制数压起来。转移即可。
那么正解比我的做法高明在什么地方呢?我们发现这道题操作的顺序其实是无所谓的。所以我们可以考虑将 DP 阶段设成元素。也就是说,解决完一个元素再解决下一个元素。
具体来说,首先求出辅助数组 \(f\)。\(f[i,S,j]\) 表示第 \(i\) 个元素再经过操作集合 \(S\) 后,值为 \(j\) 的概率。其中 \(S\) 是一个二进制数。第 \(k\) 位是 1 表示第 \(k\) 步操作选中了 \(A_i\)。
然后我们来求解 \(dp\) 数组。\(dp[i,S,j]\) 表示前 \(i\) 个元素在经过操作集合 \(S\) 后,最大值是 \(j\) 的概率。同样地,\(S\) 还是一个二进制数,第 \(k\) 位是1表示第 \(k\) 步操作选中了前 \(i\) 个元素。在这里我们采用刷表法比较容易实现。
\[dp\left[i+1,S\cup T, \max\{j, j'\}\right]+=dp[i,S,j]\times f[i+1,T,j']
\]
我们要保证 \(T\) 和 \(S\) 没有交集。
那么怎样才能想出来这种做法呢?首先我们可以看到数据范围,\(m\) 比较小,\(n\) 相对来说比较大。如果要状压的话肯定状压 \(m\) 比 \(n\) 要合适。那么像这种对操作进行状压的题目真的非常罕见,一般都是对元素进行状压。所以要平时积累这种思维。
三、代码:
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define FILEIN(s) freopen(s".in", "r", stdin);
#define FILEOUT(s) freopen(s".out", "w", stdout)
#define mem(s, v) memset(s, v, sizeof s)
inline int read(void) {
int x = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') { if (ch == '-') f = -1; ch = getchar(); }
while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
return f * x;
}
const int maxm = 12, maxn = 42, maxc = 5, maxv = 32, mod = 1e9 + 7;
int n, m, c;
long long P[maxm][maxn][maxc];
long long f[maxn][1 << maxm][maxv];
long long dp[maxn][1 << maxm][maxv];
#define call(s, i) ((s >> i) & 1)
int main() {
FILEIN("max"); FILEOUT("max");
n = read(); m = read(); c = read();
for (int i = 0; i < m; ++ i) {
for (int j = 1; j <= n; ++ j) {
for (int k = 0; k <= c; ++ k) {
P[i][j][k] = read();
}
}
}
for (int now = 1; now <= n; ++ now) {
f[now][0][0] = 1;
for (int s = 1; s < (1 << m); ++ s) // the state of operation
for (int j = 0; j <= c * m; ++ j) // the present value
for (int i = 0; i < m; ++ i)
if (call(s, i)) {
int t = s - (1 << i);
for (int k = 0; k <= min(c, j); ++ k)
(f[now][s][j] += f[now][t][j - k] * P[i][now][k] % mod) %= mod;
break;
}
}
dp[0][0][0] = 1;
for (int i = 0; i < n; ++ i)
for (int s = 0; s < (1 << m); ++ s)
for (int j = 0; j <= c * m; ++ j) { // the last maximum value
if (!dp[i][s][j]) continue;
int ss = (~s) & ((1 << m) - 1);
for (int t = ss; ; t = (t - 1) & ss) {
for (int _j = 0; _j <= c * m; ++ _j) // the present maximum value
(dp[i + 1][s | t][max(_j, j)] += dp[i][s][j] * f[i + 1][t][_j] % mod) %= mod;
if (!t) break;
}
}
long long ans = 0;
for (int j = 0; j <= c * m; ++ j)
(ans += j * dp[n][(1 << m) - 1][j] % mod) %= mod;
printf("%lld\n", ans);
return 0;
}